Skip to content

Commit

Permalink
Re-login if refreshing failed
Browse files Browse the repository at this point in the history
  • Loading branch information
Prior99 committed Sep 22, 2024
1 parent 244a2e8 commit 751b452
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 36 deletions.
87 changes: 63 additions & 24 deletions myskoda/authorization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Handles authorization to the MySkoda API."""

from asyncio import Lock
import base64
import hashlib
import json
Expand Down Expand Up @@ -58,6 +59,9 @@ class IDKAuthorizationCode(BaseModel):
id_token: str


refresh_token_lock = Lock()


class IDKSession(BaseModel):
"""Stores the JWT tokens relevant for a session at the IDK server.
Expand All @@ -68,30 +72,57 @@ class IDKSession(BaseModel):
refresh_token: str = Field(None, refreshToken="accessToken")
id_token: str = Field(None, idToken="accessToken")

async def perform_refresh(self, session: ClientSession, attempt: int = 0) -> None:
async def perform_refresh(self, session: ClientSession, username: str, password: str) -> None:
"""Refresh the authorization token.
This will consume the `refresh_token` and exchange it for a new set of tokens.
"""
json_data = {"token": self.refresh_token}
async with session.post(
f"{BASE_URL_SKODA}/api/v1/authentication/refresh-token?tokenType=CONNECT",
json=json_data,
) as response:
try:
if not response.ok:
raise InvalidStatusError(response.status) # noqa: TRY301
data = json.loads(await response.text())
self.access_token = data.get("accessToken")
self.refresh_token = data.get("refreshToken")
self.id_token = data.get("idToken")
except Exception:
if attempt >= MAX_RETRIES:
raise
_LOGGER.warning("Retrying failed request to refresh token.")
await self.perform_refresh(session, attempt=attempt + 1)

async def get_access_token(self, session: ClientSession) -> str:

async def perform_request() -> bool:
meta = jwt.decode(
self.access_token, options={"verify_signature": False}
)
expiry = datetime.fromtimestamp(cast(float, meta.get("exp")), tz=UTC)
if datetime.now(tz=UTC) + timedelta(minutes=10) < expiry:
return True
async with session.post(
f"{BASE_URL_SKODA}/api/v1/authentication/refresh-token?tokenType=CONNECT",
json=json_data,
) as response:
try:
if not response.ok:
raise InvalidStatusError(response.status) # noqa: TRY301
data = json.loads(await response.text())
self.access_token = data.get("accessToken")
self.refresh_token = data.get("refreshToken")
self.id_token = data.get("idToken")
except Exception:
return False
else:
return True

async with refresh_token_lock:
attempts = 0
while not await perform_request():
if attempts >= MAX_RETRIES:
_LOGGER.error(
"Refreshing token failed after %d attempts.", MAX_RETRIES
)
_LOGGER.info("Trying to recover by logging in again...")
try:
tokens = await idk_authorize(session, username, password)
except Exception:
_LOGGER.exception("Failed to login.")
else:
self.access_token = tokens.access_token
self.refresh_token = tokens.refresh_token
self.id_token = tokens.id_token
_LOGGER.info("Successfully recovered by logging in.")
_LOGGER.warning("Retrying failed request to refresh token. Retrying...")
attempts = attempts + 1

async def get_access_token(self, session: ClientSession, username: str, password: str) -> str:
"""Get the access token.
Use this method instead of using `access_token` directly. It will automatically
Expand All @@ -101,7 +132,7 @@ async def get_access_token(self, session: ClientSession) -> str:
expiry = datetime.fromtimestamp(cast(float, meta.get("exp")), tz=UTC)
if datetime.now(tz=UTC) + timedelta(minutes=10) > expiry:
_LOGGER.info("Refreshing IDK access token")
await self.perform_refresh(session)
await self.perform_refresh(session, username, password)
return self.access_token


Expand Down Expand Up @@ -168,7 +199,9 @@ async def _initial_oidc_authorize(
"code_challenge_method": "s256",
"prompt": "login",
}
async with session.get(f"{BASE_URL_IDENT}/oidc/v1/authorize", params=params) as response:
async with session.get(
f"{BASE_URL_IDENT}/oidc/v1/authorize", params=params
) as response:
data = _extract_states_from_website(await response.text())
return IDKCredentials(data, email, password)

Expand Down Expand Up @@ -258,14 +291,18 @@ async def _exchange_auth_code_for_idk_session(
return IDKSession(**await response.json())


async def idk_authorize(session: ClientSession, email: str, password: str) -> IDKSession:
async def idk_authorize(
session: ClientSession, email: str, password: str
) -> IDKSession:
"""Perform the full login process.
Must be called before any other methods on the class can be called.
"""
# Generate a random string for the OAUTH2 PKCE challenge.
# (https://www.oauth.com/oauth2-servers/pkce/authorization-request/)
verifier = "".join(random.choices(string.ascii_uppercase + string.digits, k=16)) # noqa: S311
verifier = "".join(
random.choices(string.ascii_uppercase + string.digits, k=16)
) # noqa: S311

# Call the initial OIDC (OpenID Connect) authorization, giving us the initial SSO information.
# The full flow is explain a little bit here:
Expand All @@ -281,7 +318,9 @@ async def idk_authorize(session: ClientSession, email: str, password: str) -> ID
authentication = await _enter_password(session, login_meta)

# Exchange the token for access and refresh tokens (JWT format).
return await _exchange_auth_code_for_idk_session(session, authentication.code, verifier)
return await _exchange_auth_code_for_idk_session(
session, authentication.code, verifier
)


class AuthorizationError(Exception):
Expand Down
36 changes: 27 additions & 9 deletions myskoda/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ class Mqtt:
should_reconnect: bool
is_connected: bool

def __init__(self, api: RestApi, ssl_context: ssl.SSLContext | None = None) -> None: # noqa: D107
def __init__(
self, api: RestApi, ssl_context: ssl.SSLContext | None = None
) -> None: # noqa: D107
self.api = api
self._callbacks = []
self._operation_listeners = []
Expand All @@ -87,7 +89,9 @@ async def perform_connect() -> bool:
if self.is_connected:
return True

_LOGGER.debug(f"Connecting to MQTT on {MQTT_BROKER_HOST}:{MQTT_BROKER_PORT}...")
_LOGGER.debug(
f"Connecting to MQTT on {MQTT_BROKER_HOST}:{MQTT_BROKER_PORT}..."
)

if not self.user:
self.user = await self.api.get_user()
Expand All @@ -109,9 +113,13 @@ async def perform_connect() -> bool:
self.client.tls_set_context(context=ssl.create_default_context())
self.client.username_pw_set(
self.user.id,
await self.api.idk_session.get_access_token(self.api.session),
await self.api.idk_session.get_access_token(
self.api.session, self.api.email, self.api.password
),
)
self.client.connect_async(
MQTT_BROKER_HOST, MQTT_BROKER_PORT, MQTT_KEEPALIVE
)
self.client.connect_async(MQTT_BROKER_HOST, MQTT_BROKER_PORT, MQTT_KEEPALIVE)

await self._wait_for_connection()
except FailedToConnectError:
Expand Down Expand Up @@ -147,7 +155,9 @@ def _wait_for_connection(self) -> Future[None]:

return future

def wait_for_operation(self, operation_name: OperationName) -> Future[OperationRequest]:
def wait_for_operation(
self, operation_name: OperationName
) -> Future[OperationRequest]:
"""Wait until the next operation of the specified type completes."""
_LOGGER.debug("Waiting for operation %s complete.", operation_name)
future: Future[OperationRequest] = get_event_loop().create_future()
Expand All @@ -156,7 +166,9 @@ def wait_for_operation(self, operation_name: OperationName) -> Future[OperationR

return future

def _on_socket_close(self, client: AsyncioPahoClient, _data: None, _socket: None) -> None:
def _on_socket_close(
self, client: AsyncioPahoClient, _data: None, _socket: None
) -> None:
if client is not self.client:
return
_LOGGER.info("Socket to MQTT broker closed.")
Expand Down Expand Up @@ -241,9 +253,13 @@ def _handle_operation_completed(self, operation: OperationRequest) -> None:
listener.future.set_exception(OperationFailedError(operation))
else:
if operation.status == OperationStatus.COMPLETED_WARNING:
_LOGGER.warning("Operation '%s' completed with warnings.", operation.operation)
_LOGGER.warning(
"Operation '%s' completed with warnings.", operation.operation
)

_LOGGER.debug("Resolving listener for operation '%s'.", operation.operation)
_LOGGER.debug(
"Resolving listener for operation '%s'.", operation.operation
)
listener.future.set_result(operation)

def _handle_operation(self, event: Event) -> None:
Expand Down Expand Up @@ -287,7 +303,9 @@ def _on_message( # noqa: C901
if len(data) == 0:
return

_LOGGER.debug("Message (%s) received for %s on topic %s: %s", event_type, vin, topic, data)
_LOGGER.debug(
"Message (%s) received for %s on topic %s: %s", event_type, vin, topic, data
)

# Messages will contain payload as JSON.
data = json.loads(msg.payload)
Expand Down
14 changes: 11 additions & 3 deletions myskoda/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ async def authenticate(self, email: str, password: str) -> bool:
Must be called before any other methods on the class can be called.
"""
self.email = email
self.password = password
self.idk_session = await idk_authorize(self.session, email, password)

_LOGGER.info("IDK Authorization was successful.")
Expand Down Expand Up @@ -149,7 +151,9 @@ async def list_vehicles(self) -> list[str]:
return [vehicle["vin"] for vehicle in json["vehicles"]]

async def _headers(self) -> dict[str, str]:
return {"authorization": f"Bearer {await self.idk_session.get_access_token(self.session)}"}
return {
"authorization": f"Bearer {await self.idk_session.get_access_token(self.session, self.email, self.password)}"
}

async def stop_air_conditioning(self, vin: str) -> None:
"""Stop the air conditioning."""
Expand Down Expand Up @@ -183,7 +187,9 @@ async def start_air_conditioning(self, vin: str, temperature: float) -> None:

async def set_target_temperature(self, vin: str, temperature: float) -> None:
"""Set the air conditioning's target temperature in °C."""
_LOGGER.debug("Setting target temperature for vehicle %s to %s", vin, str(temperature))
_LOGGER.debug(
"Setting target temperature for vehicle %s to %s", vin, str(temperature)
)
json_data = {"temperatureValue": str(temperature), "unitInCar": "CELSIUS"}
async with self.session.post(
f"{BASE_URL_SKODA}/api/v2/air-conditioning/{vin}/settings/target-temperature",
Expand Down Expand Up @@ -291,7 +297,9 @@ async def honk_flash(
) -> None:
"""Honk and/or flash."""
positions = await self.get_positions(vin)
position = next(pos for pos in positions.positions if pos.type == PositionType.VEHICLE)
position = next(
pos for pos in positions.positions if pos.type == PositionType.VEHICLE
)
json_data = {
"mode": "HONK_AND_FLASH" if honk else "FLASH",
"vehiclePosition": {
Expand Down

0 comments on commit 751b452

Please sign in to comment.