Fix unclear session errors

This commit is contained in:
Andre Basche 2023-04-12 19:14:14 +02:00
parent 33454f68b8
commit 970b94bfa7
4 changed files with 47 additions and 36 deletions

View file

@ -9,8 +9,7 @@ from urllib.parse import quote
from yarl import URL from yarl import URL
from pyhon import const from pyhon import const, exceptions
from pyhon.exceptions import HonAuthenticationError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -51,7 +50,7 @@ class HonAuth:
result += f"{15 * '='} Response {15 * '='}\n{await response.text()}\n{40 * '='}" result += f"{15 * '='} Response {15 * '='}\n{await response.text()}\n{40 * '='}"
_LOGGER.error(result) _LOGGER.error(result)
if fail: if fail:
raise HonAuthenticationError("Can't login") raise exceptions.HonAuthenticationError("Can't login")
async def _load_login(self): async def _load_login(self):
nonce = secrets.token_hex(16) nonce = secrets.token_hex(16)
@ -71,7 +70,11 @@ class HonAuth:
f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params}" f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params}"
) as response: ) as response:
self._called_urls.append((response.status, response.request_info.url)) self._called_urls.append((response.status, response.request_info.url))
if not (login_url := re.findall("url = '(.+?)'", await response.text())): text = await response.text()
if not (login_url := re.findall("url = '(.+?)'", text)):
if "oauth/done#access_token=" in text:
self._parse_token_data(text)
raise exceptions.HonNoAuthenticationNeeded()
await self._error_logger(response) await self._error_logger(response)
return False return False
async with self._session.get(login_url[0], allow_redirects=False) as redirect1: async with self._session.get(login_url[0], allow_redirects=False) as redirect1:
@ -156,6 +159,14 @@ class HonAuth:
await self._error_logger(response) await self._error_logger(response)
return "" return ""
def _parse_token_data(self, text):
if access_token := re.findall("access_token=(.*?)&", text):
self._access_token = access_token[0]
if refresh_token := re.findall("refresh_token=(.*?)&", text):
self._refresh_token = refresh_token[0]
if id_token := re.findall("id_token=(.*?)&", text):
self._id_token = id_token[0]
async def _get_token(self, url): async def _get_token(self, url):
async with self._session.get(url) as response: async with self._session.get(url) as response:
self._called_urls.append((response.status, response.request_info.url)) self._called_urls.append((response.status, response.request_info.url))
@ -179,26 +190,9 @@ class HonAuth:
if response.status != 200: if response.status != 200:
await self._error_logger(response) await self._error_logger(response)
return False return False
text = await response.text() self._parse_token_data(await response.text())
if access_token := re.findall("access_token=(.*?)&", text):
self._access_token = access_token[0]
if refresh_token := re.findall("refresh_token=(.*?)&", text):
self._refresh_token = refresh_token[0]
if id_token := re.findall("id_token=(.*?)&", text):
self._id_token = id_token[0]
return True return True
async def authorize(self):
if login_site := await self._load_login():
fw_uid, loaded, login_url = login_site
else:
return False
if not (url := await self._login(fw_uid, loaded, login_url)):
return False
if not await self._get_token(url):
return False
return await self._api_auth()
async def _api_auth(self): async def _api_auth(self):
post_headers = {"id-token": self._id_token} post_headers = {"id-token": self._id_token}
data = self._device.get() data = self._device.get()
@ -214,6 +208,20 @@ class HonAuth:
self._cognito_token = json_data["cognitoUser"]["Token"] self._cognito_token = json_data["cognitoUser"]["Token"]
return True return True
async def authenticate(self):
self.clear()
try:
if not (login_site := await self._load_login()):
raise exceptions.HonAuthenticationError("Can't open login page")
if not (url := await self._login(*login_site)):
raise exceptions.HonAuthenticationError("Can't login")
if not await self._get_token(url):
raise exceptions.HonAuthenticationError("Can't get token")
if not await self._api_auth():
raise exceptions.HonAuthenticationError("Can't get api token")
except exceptions.HonNoAuthenticationNeeded:
return
async def refresh(self): async def refresh(self):
params = { params = {
"client_id": const.CLIENT_ID, "client_id": const.CLIENT_ID,
@ -231,3 +239,10 @@ class HonAuth:
self._id_token = data["id_token"] self._id_token = data["id_token"]
self._access_token = data["access_token"] self._access_token = data["access_token"]
return await self._api_auth() return await self._api_auth()
def clear(self):
self._session.cookie_jar.clear_domain(const.AUTH_API.split("/")[-2])
self._cognito_token = ""
self._id_token = ""
self._access_token = ""
self._refresh_token = ""

View file

@ -57,7 +57,6 @@ class HonConnectionHandler(HonBaseConnectionHandler):
raise HonAuthenticationError("An email address must be specified") raise HonAuthenticationError("An email address must be specified")
if not self._password: if not self._password:
raise HonAuthenticationError("A password address must be specified") raise HonAuthenticationError("A password address must be specified")
self._request_headers = {}
@property @property
def device(self): def device(self):
@ -69,16 +68,11 @@ class HonConnectionHandler(HonBaseConnectionHandler):
return self return self
async def _check_headers(self, headers): async def _check_headers(self, headers):
if ( if not (self._auth.cognito_token and self._auth.id_token):
"cognito-token" not in self._request_headers await self._auth.authenticate()
or "id-token" not in self._request_headers headers["cognito-token"] = self._auth.cognito_token
): headers["id-token"] = self._auth.id_token
if await self._auth.authorize(): return self._HEADERS | headers
self._request_headers["cognito-token"] = self._auth.cognito_token
self._request_headers["id-token"] = self._auth.id_token
else:
raise HonAuthenticationError("Can't login")
return self._HEADERS | headers | self._request_headers
@asynccontextmanager @asynccontextmanager
async def _intercept(self, method, *args, loop=0, **kwargs): async def _intercept(self, method, *args, loop=0, **kwargs):
@ -98,8 +92,6 @@ class HonConnectionHandler(HonBaseConnectionHandler):
response.status, response.status,
await response.text(), await response.text(),
) )
self._request_headers = {}
self._session.cookie_jar.clear_domain(const.AUTH_API.split("/")[-2])
await self.create() await self.create()
async with self._intercept( async with self._intercept(
method, *args, loop=loop + 1, **kwargs method, *args, loop=loop + 1, **kwargs

View file

@ -1,2 +1,6 @@
class HonAuthenticationError(Exception): class HonAuthenticationError(Exception):
pass pass
class HonNoAuthenticationNeeded(Exception):
pass

View file

@ -7,7 +7,7 @@ with open("README.md", "r") as f:
setup( setup(
name="pyhOn", name="pyhOn",
version="0.7.2", version="0.7.3",
author="Andre Basche", author="Andre Basche",
description="Control hOn devices with python", description="Control hOn devices with python",
long_description=long_description, long_description=long_description,