diff --git a/heatchmap/gpmap.py b/heatchmap/gpmap.py index 7a78183..20bd312 100644 --- a/heatchmap/gpmap.py +++ b/heatchmap/gpmap.py @@ -22,12 +22,20 @@ from map_based_model import MapBasedModel from utils.utils_models import fit_gpr_silent import glob +from huggingface_hub import hf_hub_download class GPMap(MapBasedModel): def __init__(self, region="world", resolution=10, version="prod"): - self.gpr_path = "models/kernel.pkl" + self.points_path = "dump.sqlite" + if os.path.exists("models/kernel.pkl"): + self.gpr_path = "models/kernel.pkl" + else: + REPO_ID = "tillwenke/heatchmap-model" + FILENAME = "Unfitted_GaussianProcess_TransformedTargetRegressorWithUncertainty.pkl" + self.gpr_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME) + with open(self.gpr_path, "rb") as file: self.gpr = pickle.load(file) @@ -247,4 +255,4 @@ def get_landmass_raster(self): def recalc(): gpmap = GPMap(resolution=1) - gpmap.recalc_map() \ No newline at end of file + gpmap.recalc_map()