# Imports added at the top of the module by this change.
from textwrap import dedent

from open_bus_stride_db.db import get_session


class SIRI_AGG_VELOCITY_STATS_PYDANTIC_MODEL(pydantic.BaseModel):
    """One cached row of per-day velocity statistics for a rounded lon/lat cell."""

    # Coordinates snapped to the nearest 1/500 (0.002) by the aggregation query.
    lon_round: float
    lat_round: float
    date: datetime.date
    # Optional: presumably NULL when a cell has only one hourly average, since
    # a sample standard deviation is undefined for a single value.
    stddev_hourly_avg: typing.Optional[float]
    avg_hourly_avg: float
    # Sum of the raw sample counts of all hourly buckets that were kept.
    sample_number: int
    median_hourly_avg: float
    # Set to NOW() when the row is inserted into the cache table.
    last_used: typing.Optional[datetime.datetime]


# Column names of a cached row, in the order they appear in the response.
_AGG_VELOCITY_STATS_FIELDS = (
    "lon_round",
    "lat_round",
    "date",
    "stddev_hourly_avg",
    "avg_hourly_avg",
    "sample_number",
    "median_hourly_avg",
    "last_used",
)

# Read the cached aggregates of a single date.
_SELECT_AGG_VELOCITY_STATS_SQL = dedent("""
    SELECT
        siri_agg_velocity_stats.lat_round,
        lon_round,
        date,
        stddev_hourly_avg,
        avg_hourly_avg,
        sample_number,
        median_hourly_avg,
        last_used
    FROM siri_agg_velocity_stats
    WHERE date = :date
""")

# Aggregate raw vehicle locations of a single date: first into hourly averages
# per rounded cell (only buckets with more than 5 samples are kept), then into
# one stddev / average / median of those hourly averages per cell.
_CALCULATE_AGG_VELOCITY_STATS_SQL = dedent("""
    WITH HourlyAverages AS (
        SELECT
            trunc(lon * 500 + .5) / 500 AS lon_round,
            trunc(lat * 500 + .5) / 500 AS lat_round,
            DATE(recorded_at_time) AS date,
            DATE_PART('hour', recorded_at_time) AS hour,
            AVG(velocity) AS hourly_avg,
            COUNT(1) AS sample_number
        FROM siri_vehicle_location svl
        WHERE
            velocity > 0 AND velocity < 200
            AND lon > 0 AND lat > 0
            AND DATE(recorded_at_time) = :date
        GROUP BY lon_round, lat_round, date, hour
        HAVING COUNT(1) > 5
    )
    SELECT
        lon_round,
        lat_round,
        date,
        STDDEV(hourly_avg) AS stddev_hourly_avg,
        AVG(hourly_avg) AS avg_hourly_avg,
        SUM(sample_number) AS sample_number,
        PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY hourly_avg) AS median_hourly_avg
    FROM HourlyAverages
    WHERE date = :date
    GROUP BY lon_round, lat_round, date
""")

# Evict cached rows whose last_used is older than that of the row found at
# offset 10000 when ordering by last_used. Deletes nothing while the table
# holds 10,000 rows or fewer (the sub-select then yields NULL).
# NOTE(review): the ascending sort means this drops the oldest rows rather
# than keeping only the newest 10,000 - confirm this is the intended policy.
_EVICT_AGG_VELOCITY_STATS_SQL = dedent("""
    with delete_from as (
        select last_used
        from siri_agg_velocity_stats
        order by last_used
        limit 1
        offset 10000
    )
    delete from siri_agg_velocity_stats
    where last_used < (select last_used from delete_from)
""")

# Store one aggregated row, stamping it with the insertion time.
_INSERT_AGG_VELOCITY_STATS_SQL = dedent("""
    INSERT INTO siri_agg_velocity_stats (
        lon_round, lat_round, date,
        stddev_hourly_avg, avg_hourly_avg,
        sample_number, median_hourly_avg,
        last_used
    )
    VALUES (
        :lon_round, :lat_round, :date,
        :stddev_hourly_avg, :avg_hourly_avg,
        :sample_number, :median_hourly_avg,
        NOW()
    )
""")


def _calculate_and_cache_agg_velocity_stats(db, sql_params):
    """Aggregate one date from the raw locations, cache it and return the cached rows.

    Args:
        db: Open session obtained from ``get_session()``.
        sql_params: Bind parameters holding the requested ``date``.

    Returns:
        The freshly cached rows for the date, or an empty list when the raw
        data yields no aggregates (in which case nothing is written).
    """
    new_data = db.execute(_CALCULATE_AGG_VELOCITY_STATS_SQL, sql_params).fetchall()
    if not new_data:
        return []

    # Evict old cache rows before adding new ones. The statement used to be
    # built but never executed, so the cache table could only ever grow.
    db.execute(_EVICT_AGG_VELOCITY_STATS_SQL)
    for row in new_data:
        db.execute(_INSERT_AGG_VELOCITY_STATS_SQL, dict(row))
    # Eviction and inserts are committed together as a single transaction.
    db.commit()

    # Re-read from the cache table so the rows include the stored last_used.
    return db.execute(_SELECT_AGG_VELOCITY_STATS_SQL, sql_params).fetchall()


@router.get(
    "/siri-agg-velocity-stats",
    tags=[TAG],
    response_model=typing.List[SIRI_AGG_VELOCITY_STATS_PYDANTIC_MODEL],
    description="Retrieve aggregated velocity stats for a given date (with cache mechanism).",
)
def get_or_insert_agg_velocity_stats(
    date: datetime.date = common.doc_param("date", filter_type="date", default=...)
):
    """Return the aggregated velocity stats of ``date``, computing them on a cache miss.

    The ``siri_agg_velocity_stats`` table serves as the cache. When it holds no
    rows for the date, the aggregates are calculated from
    ``siri_vehicle_location``, old cache rows are evicted, the new rows are
    inserted and then read back. ``last_used`` is written at insert time only;
    nothing here updates it on a cache hit.

    Args:
        date: The day to return statistics for.

    Returns:
        A list of dicts, one per rounded lon/lat cell, with the keys of
        ``SIRI_AGG_VELOCITY_STATS_PYDANTIC_MODEL``. Empty when the date has no
        qualifying vehicle locations.
    """
    # NOTE(review): statements are passed to execute() as plain strings, which
    # SQLAlchemy 2.0 rejects unless wrapped in text() - confirm the pinned version.
    sql_params = {"date": date}
    with get_session() as db:
        results = db.execute(_SELECT_AGG_VELOCITY_STATS_SQL, sql_params).fetchall()
        if not results:
            results = _calculate_and_cache_agg_velocity_stats(db, sql_params)

    return [
        {field: getattr(result, field) for field in _AGG_VELOCITY_STATS_FIELDS}
        for result in results
    ]