From 93bf983eb604d6fffd941d921a1194a2b237ff1f Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Mon, 2 Jan 2023 23:56:16 +0200 Subject: [PATCH] Add request sessions, change refresh token expiry time --- ThermiaOnlineAPI/api/ThermiaAPI.py | 55 ++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/ThermiaOnlineAPI/api/ThermiaAPI.py b/ThermiaOnlineAPI/api/ThermiaAPI.py index 583da79..a1586cc 100644 --- a/ThermiaOnlineAPI/api/ThermiaAPI.py +++ b/ThermiaOnlineAPI/api/ThermiaAPI.py @@ -2,6 +2,7 @@ from collections import ChainMap from datetime import datetime, timedelta import requests +from requests.adapters import HTTPAdapter, Retry from requests import cookies import json import hashlib @@ -59,6 +60,11 @@ def __init__(self, email, password, api_type): "Content-Type": "application/json", } + self.__session = requests.Session() + retry = Retry(connect=3, backoff_factor=0.5) + adapter = HTTPAdapter(max_retries=retry) + self.__session.mount("https://", adapter) + if api_type not in THERMIA_API_CONFIG_URLS_BY_API_TYPE: raise ValueError("Unknown device type: " + api_type) @@ -71,7 +77,7 @@ def get_devices(self): self.__check_token_validity() url = self.configuration["apiBaseUrl"] + "/api/v1/InstallationsInfo/own" - request = requests.get(url, headers=self.__default_request_headers) + request = self.__session.get(url, headers=self.__default_request_headers) status = request.status_code if status != 200: @@ -97,7 +103,7 @@ def get_device_info(self, device_id: str): self.__check_token_validity() url = self.configuration["apiBaseUrl"] + "/api/v1/installations/" + device_id - request = requests.get(url, headers=self.__default_request_headers) + request = self.__session.get(url, headers=self.__default_request_headers) status = request.status_code if status != 200: @@ -115,7 +121,7 @@ def get_device_status(self, device_id: str): + device_id + "/status" ) - request = requests.get(url, headers=self.__default_request_headers) + request = self.__session.get(url, headers=self.__default_request_headers) status = request.status_code if status != 200: @@ -133,7 +139,7 @@ def get_all_alarms(self, device_id: str): + str(device_id) + "/events?onlyActiveAlarms=false" ) - request = requests.get(url, headers=self.__default_request_headers) + request = self.__session.get(url, headers=self.__default_request_headers) status = request.status_code if status != 200: @@ -150,7 +156,7 @@ def get_historical_data_registers(self, device_id: str): + "/api/v1/DataHistory/installation/" + str(device_id) ) - request = requests.get(url, headers=self.__default_request_headers) + request = self.__session.get(url, headers=self.__default_request_headers) status = request.status_code if status != 200: @@ -175,7 +181,7 @@ def get_historical_data( + "&periodEnd=" + end_date_str ) - request = requests.get(url, headers=self.__default_request_headers) + request = self.__session.get(url, headers=self.__default_request_headers) status = request.status_code if status != 200: @@ -196,7 +202,7 @@ def get_all_available_groups(self, installation_profile_id: int): + "/groups" ) - request = requests.get(url, headers=self.__default_request_headers) + request = self.__session.get(url, headers=self.__default_request_headers) status = request.status_code if status != 200: @@ -408,7 +414,7 @@ def __get_register_group(self, device_id: str, register_group: str) -> list: + "/Groups/" + register_group ) - request = requests.get(url, headers=self.__default_request_headers) + request = self.__session.get(url, headers=self.__default_request_headers) status = request.status_code if status != 200: @@ -439,7 +445,9 @@ def __set_register_value( "clientUuid": "api-client-uuid", } - request = requests.post(url, headers=self.__default_request_headers, json=body) + request = self.__session.post( + url, headers=self.__default_request_headers, json=body + ) status = request.status_code if status != 200: @@ -451,7 +459,7 @@ def __set_register_value( ) def __fetch_configuration(self): - request = requests.get(self.__api_config_url) + request = self.__session.get(self.__api_config_url) status = request.status_code if status != 200: @@ -469,20 +477,24 @@ def __authenticate_refresh_token(self) -> Optional[str]: "grant_type": "refresh_token", } - request_token = requests.post( + request_token = self.__session.post( AZURE_AUTH_GET_TOKEN_URL, headers=azure_auth_request_headers, data=request_token__data, ) if request_token.status_code != 200: + self.__refresh_token = None + self.__refresh_token_valid_to = None + error_text = ( "Reauthentication request failed with previous refresh token. Status: " + str(request_token.status_code) + ", Response: " + request_token.text ) - _LOGGER.error(error_text) + _LOGGER.info(error_text) + return None return request_token.text @@ -498,6 +510,9 @@ def __authenticate(self) -> bool: request_token_text = self.__authenticate_refresh_token() if request_token_text is None: # New token, or refresh failed + self.__token = None + self.__token_valid_to = None + code_challenge = utils.generate_challenge(43) request_auth__data = { @@ -514,7 +529,9 @@ def __authenticate(self) -> bool: "code_challenge_method": "S256", } - request_auth = requests.get(AZURE_AUTH_AUTHORIZE_URL, request_auth__data) + request_auth = self.__session.get( + AZURE_AUTH_AUTHORIZE_URL, data=request_auth__data + ) state_code = "" csrf_token = "" @@ -549,7 +566,7 @@ def __authenticate(self) -> bool: "p": "B2C_1A_SignUpOrSigninOnline", } - request_self_asserted = requests.post( + request_self_asserted = self.__session.post( AZURE_SELF_ASSERTED_URL, cookies=request_auth.cookies, data=request_self_asserted__data, @@ -582,7 +599,7 @@ def __authenticate(self) -> bool: "p": "B2C_1A_SignUpOrSigninOnline", } - request_confirmed = requests.get( + request_confirmed = self.__session.get( AZURE_AUTH_CONFIRM_URL, cookies=request_confirmed__cookies, params=request_confirmed__params, @@ -597,7 +614,7 @@ def __authenticate(self) -> bool: "grant_type": "authorization_code", } - request_token = requests.post( + request_token = self.__session.post( AZURE_AUTH_GET_TOKEN_URL, headers=azure_auth_request_headers, data=request_token__data, @@ -620,9 +637,9 @@ def __authenticate(self) -> bool: self.__token = token_data["access_token"] self.__token_valid_to = token_data["expires_on"] - # refresh token valid for 24h, but we refresh it every 12h for safety + # refresh token valid for 24h (maybe), but we refresh it every 6h for safety self.__refresh_token_valid_to = ( - datetime.now() + timedelta(hours=12) + datetime.now() + timedelta(hours=6) ).timestamp() self.__refresh_token = token_data.get("refresh_token") @@ -639,6 +656,8 @@ def __check_token_validity(self): if ( self.__token_valid_to is None or self.__token_valid_to < datetime.now().timestamp() + or self.__refresh_token_valid_to is None + or self.__refresh_token_valid_to < datetime.now().timestamp() ): _LOGGER.info("Token expired, re-authenticating.") self.authenticated = self.__authenticate()