Skip to content

Commit

Permalink
Simplify GaiaEvents.validate_payload() (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
vaamb authored Nov 23, 2024
1 parent cc2821f commit 900f747
Showing 1 changed file with 16 additions and 39 deletions.
55 changes: 16 additions & 39 deletions src/ouranos/aggregator/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from anyio import Path as ioPath
from anyio.to_thread import run_sync
from pydantic import TypeAdapter, ValidationError
from pydantic import RootModel, ValidationError
from sqlalchemy import delete, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -97,31 +97,10 @@ def __init__(self, *args, **kwargs) -> None:
def validate_payload(
self,
data: PT,
model_cls: Type[gv.BaseModel],
type_: Type[dict] | Type[list],
model_cls: Type[gv.BaseModel] | Type[RootModel],
) -> PT:
if not data:
event = inspect.stack()[1].function.lstrip("on_")
self.logger.error(
f"Encountered an error while validating '{event}' data. Error "
f"msg: Empty data."
)
raise ValidationError("Empty data")
if not isinstance(data, type_):
event = inspect.stack()[1].function.lstrip("on_")
received = type(data)
self.logger.error(
f"Encountered an error while validating '{event}' data. Error "
f"msg: Wrong data format, expected '{type_}', received "
f"'{received}'."
)
raise ValidationError(f"Data is not of the expected type '{type_}'")
try:
if isinstance(data, list):
temp: list[gv.BaseModel] = TypeAdapter(list[model_cls]).validate_python(data)
return [obj.model_dump() for obj in temp]
elif isinstance(data, dict):
return model_cls(**data).model_dump()
return model_cls.model_validate(data).model_dump()
except ValidationError as e:
event = inspect.stack()[1].function.lstrip("on_")
msg_list = [f"{error['loc'][0]}: {error['msg']}" for error in e.errors()]
Expand Down Expand Up @@ -240,8 +219,7 @@ async def on_register_engine(
sid: UUID,
data: gv.EnginePayloadDict,
) -> None:
data: gv.EnginePayloadDict = self.validate_payload(
data, gv.EnginePayload, dict)
data: gv.EnginePayloadDict = self.validate_payload(data, gv.EnginePayload)
engine_uid: str = data["engine_uid"]
remote_addr: str = data["address"]
self.logger.info(
Expand Down Expand Up @@ -327,8 +305,7 @@ async def on_places_list(
) -> None:
self.logger.debug(
f"Received 'places_list' from {engine_uid}.")
payload: gv.PlacesPayloadDict = self.validate_payload(
data, gv.PlacesPayload, dict)
payload: gv.PlacesPayloadDict = self.validate_payload(data, gv.PlacesPayload)
async with db.scoped_session() as session:
for place in payload["data"]:
await Place.update_or_create(
Expand All @@ -353,7 +330,7 @@ async def on_base_info(
async with self.session(sid) as session:
session["init_data"].discard("base_info")
data: list[gv.BaseInfoConfigPayloadDict] = self.validate_payload(
data, gv.BaseInfoConfigPayload, list)
data, RootModel[list[gv.BaseInfoConfigPayload]])
ecosystems_in_config: list[str] = []
ecosystems_status: list[dict[str, str]] = []
ecosystems_to_log: list[str] = []
Expand Down Expand Up @@ -421,7 +398,7 @@ async def on_environmental_parameters(
async with self.session(sid) as session:
session["init_data"].discard("environmental_parameters")
data: list[gv.EnvironmentConfigPayloadDict] = self.validate_payload(
data, gv.EnvironmentConfigPayload, list)
data, RootModel[list[gv.EnvironmentConfigPayload]])
ecosystems_to_log: list[str] = []
async with db.scoped_session() as session:
for payload in data:
Expand Down Expand Up @@ -469,7 +446,7 @@ async def on_hardware(
async with self.session(sid) as session:
session["init_data"].discard("hardware")
data: list[gv.HardwareConfigPayloadDict] = self.validate_payload(
data, gv.HardwareConfigPayload, list)
data, RootModel[list[gv.HardwareConfigPayload]])
ecosystems_to_log: list[str] = []
async with db.scoped_session() as session:
for payload in data:
Expand Down Expand Up @@ -516,7 +493,7 @@ async def on_management(
async with self.session(sid) as session:
session["init_data"].discard("management")
data: list[gv.ManagementConfigPayloadDict] = self.validate_payload(
data, gv.ManagementConfigPayload, list)
data, RootModel[list[gv.ManagementConfigPayload]])

class EcosystemUpdateData(TypedDict):
management: str
Expand Down Expand Up @@ -561,7 +538,7 @@ async def on_sensors_data(
self.logger.debug(
f"Received 'sensors_data' from engine: {engine_uid}")
data: list[gv.SensorsDataPayloadDict] = self.validate_payload(
data, gv.SensorsDataPayload, list)
data, RootModel[list[gv.SensorsDataPayload]])
sensors_data: list[SensorDataRecordDict] = []
alarms_data: list[SensorAlarmDict] = []
for ecosystem in data:
Expand Down Expand Up @@ -715,7 +692,7 @@ async def on_buffered_sensors_data(
self.logger.debug(
f"Received 'buffered_sensors_data' from {engine_uid}")
data: gv.BufferedSensorsDataPayloadDict = self.validate_payload(
data, gv.BufferedSensorsDataPayload, dict)
data, gv.BufferedSensorsDataPayload)
exchange_uuid: UUID = data["uuid"]
records = [
{
Expand Down Expand Up @@ -745,7 +722,7 @@ async def on_actuators_data(
async with self.session(sid) as session:
session["init_data"].discard("actuators_data")
data: list[gv.ActuatorsDataPayloadDict] = self.validate_payload(
data, gv.ActuatorsDataPayload, list)
data, RootModel[list[gv.ActuatorsDataPayload]])

class AwareActuatorStateDict(gv.ActuatorStateDict):
ecosystem_uid: str
Expand Down Expand Up @@ -814,7 +791,7 @@ async def on_buffered_actuators_data(
self.logger.debug(
f"Received 'buffered_actuators_data' from {engine_uid}")
data: gv.BufferedActuatorsStatePayloadDict = self.validate_payload(
data, gv.BufferedActuatorsStatePayload, dict)
data, gv.BufferedActuatorsStatePayload)
exchange_uuid: UUID = data["uuid"]
records = [
{
Expand Down Expand Up @@ -847,7 +824,7 @@ async def on_health_data(
async with self.session(sid) as session:
session["init_data"].discard("health_data")
data: list[gv.HealthDataPayloadDict] = self.validate_payload(
data, gv.HealthDataPayload, list)
data, RootModel[list[gv.HealthDataPayload]])
logged: list[str] = []
values: list[dict] = []
async with db.scoped_session() as session:
Expand Down Expand Up @@ -883,7 +860,7 @@ async def on_light_data(
async with self.session(sid) as session:
session["init_data"].discard("light_data")
data: list[gv.LightDataPayloadDict] = self.validate_payload(
data, gv.LightDataPayload, list)
data, RootModel[list[gv.LightDataPayload]])
ecosystems_to_log: list[str] = []
async with db.scoped_session() as session:
for payload in data:
Expand Down Expand Up @@ -912,7 +889,7 @@ async def _turn_actuator(
data: gv.TurnActuatorPayloadDict
) -> None:
data: gv.TurnActuatorPayloadDict = self.validate_payload(
data, gv.TurnActuatorPayload, dict)
data, gv.TurnActuatorPayload)
async with db.scoped_session() as session:
ecosystem_uid = data["ecosystem_uid"]
ecosystem = await Ecosystem.get(session, uid=ecosystem_uid)
Expand Down

0 comments on commit 900f747

Please sign in to comment.