
Commit: with uncertainties
tillwenke committed Jan 12, 2025
1 parent 0453e72 commit 59ae327
Showing 2 changed files with 29 additions and 35 deletions.
62 changes: 28 additions & 34 deletions heatchmap/gpmap.py
@@ -65,40 +65,38 @@ def __init__(self, region="world", resolution=10, version="prod", visual:bool=Fa

map_dataset_dict = load_dataset("tillwenke/heatchmap-map", cache_dir=f"{HERE}/cache/huggingface")
# choosing the latest map; dataset splits are dates
- split = list(map_dataset_dict.keys())[-1]
- logger.info(f"Loading map from {split}.")
- self.map_dataset = map_dataset_dict[split]
- self.map_dataset = self.map_dataset.with_format("np")
- self.raw_raster = self.map_dataset["numpy"]
+ splits = list(map_dataset_dict.keys())
+ if len(splits) == 0:
+     logger.info("No map found in huggingface dataset. Recalculating whole map.")
+     self.map_dataset = None
+     self.raw_raster = None
+     self.uncertainties = None
+ else:
+     split = splits[-1]
+     logger.info(f"Loading map from {split}.")
+     self.map_dataset = map_dataset_dict[split]
+     self.map_dataset = self.map_dataset.with_format("np")
+     self.raw_raster = self.map_dataset["numpy"]
+     self.uncertainties = self.map_dataset["uncertainties"]

# files = glob.glob(f"intermediate/map_{self.method}_{self.region}_{self.resolution}_{self.version}*.txt")

self.today = pd.Timestamp.now()

try:
self.begin = pd.Timestamp.strptime(split, "%Y.%m.%d")
logger.info(f"Last map update was on {self.begin.date()}.")
except Exception as e:
self.begin = pd.Timestamp(self.today.date() - pd.Timedelta(days=1))
logger.info(f"No map update found with {e}. Using yesterday as begin date.")

self.batch_size = 10000

# self.map_path = f"intermediate/map_{self.method}_{self.region}_{self.resolution}_{self.version}_{self.today.date()}.txt"

self.batch_size = 10000
self.recalc_radius = 800000 # TODO: determine from model largest influence radius

self.shapely_countries = f"{self.cache_dir}/countries/ne_110m_admin_0_countries.shp"

if not os.path.exists(self.shapely_countries):
output_dir = f"{self.cache_dir}/countries"
os.makedirs(output_dir, exist_ok=True)

# URL for the 110m countries shapefile from Natural Earth
url = "https://naturalearth.s3.amazonaws.com/110m_cultural/ne_110m_admin_0_countries.zip"



# Download the dataset
logger.info("Downloading countries dataset...")
response = requests.get(url)
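The hunk is truncated at this point. For orientation, here is a minimal standalone sketch of the usual requests-plus-zipfile pattern for fetching and unpacking the Natural Earth archive; the extraction step and the literal output_dir value are assumptions for illustration, not the committed code.

import io
import os
import zipfile

import requests

# Sketch only: download the 110m countries shapefile and unpack it locally.
url = "https://naturalearth.s3.amazonaws.com/110m_cultural/ne_110m_admin_0_countries.zip"
output_dir = "cache/countries"  # stands in for f"{self.cache_dir}/countries"
os.makedirs(output_dir, exist_ok=True)

response = requests.get(url)
response.raise_for_status()  # fail loudly if the download did not succeed
with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
    archive.extractall(output_dir)  # yields ne_110m_admin_0_countries.shp among other files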
@@ -131,10 +129,6 @@ def recalc_map(self):
self.gpr.regressor.optimizer = None
self.gpr = fit_gpr_silent(self.gpr, X, y)

# recalc the old map

# self.raw_raster = np.loadtxt(self.old_map_path)

self.get_map_grid()
self.get_recalc_raster()

@@ -153,17 +147,20 @@ def recalc_map(self):
pixels_to_predict.append((y, x))
# batching the model calls
if len(to_predict) == self.batch_size:
- prediction = self.gpr.predict(np.array(to_predict), return_std=False)
+ waiting_times, uncertainties = self.gpr.predict(np.array(to_predict), return_std=True)
for i, (y, x) in enumerate(pixels_to_predict):
- self.raw_raster[y][x] = prediction[i]
+ self.raw_raster[y][x] = waiting_time[i]

Check failure (GitHub Actions / build), Ruff (F821): heatchmap/gpmap.py:152:49: Undefined name `waiting_time`
+ self.uncertainties[y][x] = uncertainties[i]

to_predict = []
pixels_to_predict = []

if len(to_predict) > 0:
- prediction = self.gpr.predict(np.array(to_predict), return_std=False)
+ waiting_time, uncertainties = self.gpr.predict(np.array(to_predict), return_std=True)
for i, (y, x) in enumerate(pixels_to_predict):
- self.raw_raster[y][x] = prediction[i]
+ self.raw_raster[y][x] = waiting_time[i]
+ self.uncertainties[y][x] = uncertainties[i]


logger.info(f"Time elapsed to compute full map: {time.time() - start}")
logger.info(
@@ -172,9 +169,6 @@ def recalc_map(self):
logger.info(f"Only {self.recalc_raster.sum()} pixels were recalculated. That is {self.recalc_raster.sum() / (self.raw_raster.shape[0] * self.raw_raster.shape[1]) * 100}% of the map.")

Check failure (GitHub Actions / build), Ruff (E501): heatchmap/gpmap.py:169:131: Line too long (191 > 130)
logger.info(f"And time per recalculated pixel was {(time.time() - start) / self.recalc_raster.sum()} seconds")

# np.savetxt(self.map_path, self.raw_raster)
# self.save_as_rasterio()
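The batched calls above are the core of this commit: with return_std=True, scikit-learn's GaussianProcessRegressor.predict returns the posterior mean and the per-point standard deviation, and both arrays line up index for index with the batch of pixel coordinates. A minimal standalone sketch of that pattern follows (toy training data and grid sizes are invented for illustration; the mean is kept under a single name throughout, which is also what the Ruff F821 annotation above flags in the first batch):

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor

# Sketch only: fit a toy GP on fake waiting-time data.
rng = np.random.default_rng(0)
X_train = rng.uniform(0, 100, size=(50, 2))   # spot coordinates (x, y)
y_train = rng.uniform(5, 60, size=50)         # waiting times in minutes
gpr = GaussianProcessRegressor().fit(X_train, y_train)

height, width, batch_size = 20, 30, 128
raster = np.zeros((height, width))            # predicted waiting times
uncertainties = np.zeros((height, width))     # predictive standard deviations

to_predict, pixels = [], []
for y in range(height):
    for x in range(width):
        to_predict.append((x * 3.3, y * 5.0))  # pixel centre in map coordinates
        pixels.append((y, x))
        if len(to_predict) == batch_size:
            mean, std = gpr.predict(np.array(to_predict), return_std=True)
            for i, (py, px) in enumerate(pixels):
                raster[py][px] = mean[i]
                uncertainties[py][px] = std[i]
            to_predict, pixels = [], []

if to_predict:  # flush the final partial batch
    mean, std = gpr.predict(np.array(to_predict), return_std=True)
    for i, (py, px) in enumerate(pixels):
        raster[py][px] = mean[i]
        uncertainties[py][px] = std[i]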

def show_raster(self, raster: np.array):
"""Show the raster in a plot.
@@ -213,7 +207,7 @@ def get_recalc_raster(self):
recalc_radius_pixels = int(np.ceil(abs(self.recalc_radius / (self.grid[0][0][0] - self.grid[0][0][1]))))
self.get_landmass_raster()

- if self.raw_raster is None:
+ if self.raw_raster is None or self.uncertainties is None:
logger.info("No map found. Recalculating whole map.")
self.recalc_raster = np.ones(self.grid.shape[1:])
else:
@@ -315,18 +309,18 @@ def get_landmass_raster():
# cleanup
os.remove(self.landmass_path)

- def upload(self):
+ def upload(self, latest_timestamp_in_dataset: pd.Timestamp = self.today):
"""Uploads the recalculated map to the Hugging Face model hub.
Clean cached files.
"""
logger.info(f"Shape of uploading map: {self.raw_raster.shape}")
- d = {"numpy": self.raw_raster}
- ds = Dataset.from_dict(d)
- ds = ds.with_format("np")
- ds_dict = DatasetDict({self.today.strftime("%Y.%m.%d"): ds})
+ data_dict = {"numpy": self.raw_raster, "uncertainties": self.uncertainties}
+ dataset = Dataset.from_dict(data_dict)
+ dataset = dataset.with_format("np")
+ dataset_dict = DatasetDict({latest_timestamp_in_dataset.strftime("%Y.%m.%d"): dataset})

- ds_dict.push_to_hub("tillwenke/heatchmap-map")
+ dataset_dict.push_to_hub("tillwenke/heatchmap-map")
logger.info("Uploaded new map to Hugging Face dataset hub.")

shutil.rmtree(self.cache_dir)
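Because upload() keys the new split by a %Y.%m.%d date, a consumer can read back the freshest map together with its uncertainty raster roughly as follows. This is a sketch, not part of the commit: it sorts the split names explicitly (zero-padded %Y.%m.%d strings sort chronologically) instead of relying on dictionary order, and uses the column names from the diff above.

import numpy as np
import pandas as pd
from datasets import load_dataset

# Sketch only: load the newest map and its uncertainties from the pushed dataset.
map_dataset_dict = load_dataset("tillwenke/heatchmap-map")
latest_split = sorted(map_dataset_dict.keys())[-1]
last_update = pd.to_datetime(latest_split, format="%Y.%m.%d")

dataset = map_dataset_dict[latest_split].with_format("np")
raw_raster = np.asarray(dataset["numpy"])
uncertainties = np.asarray(dataset["uncertainties"])

print(f"Map from {last_update.date()}: raster {raw_raster.shape}, uncertainties {uncertainties.shape}")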
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

import setuptools

- VERSION = "0.1.22"
+ VERSION = "0.1.23"

NAME = "heatchmap"

