diff --git a/mapreader/spot_text/runner_base.py b/mapreader/spot_text/runner_base.py index 77c1ee71..352c7378 100644 --- a/mapreader/spot_text/runner_base.py +++ b/mapreader/spot_text/runner_base.py @@ -16,7 +16,7 @@ from tqdm.auto import tqdm from mapreader import MapImages -from mapreader.utils.load_frames import load_from_csv, load_from_geojson +from mapreader.utils.load_frames import eval_dataframe, load_from_csv, load_from_geojson from .dataclasses import GeoPrediction, ParentPrediction, PatchPrediction @@ -478,18 +478,18 @@ def convert_to_coords( def save_to_geojson( self, - save_path: str | pathlib.Path, + path_save: str | pathlib.Path, centroid: bool = False, ) -> None: """Save the georeferenced predictions to a GeoJSON file. Parameters ---------- - save_path : str | pathlib.Path, optional + path_save : str | pathlib.Path, optional Path to save the GeoJSON file centroid : bool, optional - Whether to save the centroid of the polygons as the geometry column, by default False. - Note: The original polygon will stil be saved as a separate column. + Whether to convert the polygons to centroids, by default False. + NOTE: The original polygon will still be saved as a separate column """ if self.geo_predictions == {}: raise ValueError( @@ -500,12 +500,61 @@ def save_to_geojson( if centroid: geo_df["polygon"] = geo_df["geometry"].to_wkt() - geo_df["geometry"] = geo_df["geometry"].apply(self._polygon_to_centroid) + geo_df["geometry"] = geo_df["geometry"].centroid - geo_df.to_file(save_path, driver="GeoJSON", engine="pyogrio") + geo_df.to_file(path_save, driver="GeoJSON", engine="pyogrio") - def _polygon_to_centroid(self, polygon): - return polygon.centroid + def save_to_csv( + self, + path_save: str | pathlib.Path, + centroid: bool = False, + ) -> None: + """Saves the patch, parent and georeferenced predictions to CSV files. + + Parameters + ---------- + path_save : str | pathlib.Path + The path to save the CSV files. Files will be saved as `patch_predictions.csv`, `parent_predictions.csv` and `geo_predictions.csv`. + centroid : bool, optional + Whether to convert polygons to centroids, by default False. + NOTE: The original polygon will still be saved as a separate column. + + Note + ---- + Use the `save_to_geojson` method to save georeferenced predictions to a GeoJSON file. + """ + if self.patch_predictions == {}: # implies no parent or geo predictions + raise ValueError("[ERROR] No patch predictions found.") + + if not os.path.exists(path_save): + os.makedirs(path_save) + + print("[INFO] Saving patch predictions.") + patch_df = self._dict_to_dataframe(self.patch_predictions) + if centroid: + patch_df["polygon"] = patch_df["pixel_geometry"] + patch_df["pixel_geometry"] = patch_df["pixel_geometry"].apply( + lambda x: x.centroid + ) + patch_df.to_csv(f"{path_save}/patch_predictions.csv") + + if self.parent_predictions != {}: + print("[INFO] Saving parent predictions.") + parent_df = self._dict_to_dataframe(self.parent_predictions) + if centroid: + parent_df["polygon"] = parent_df["pixel_geometry"] + parent_df["pixel_geometry"] = parent_df["pixel_geometry"].apply( + lambda x: x.centroid + ) + parent_df.to_csv(f"{path_save}/parent_predictions.csv") + + if self.geo_predictions != {}: + print("[INFO] Saving geo predictions.") + geo_df = self._dict_to_dataframe(self.geo_predictions) + if centroid: + geo_df["polygon"] = geo_df["geometry"] + geo_df["geometry"] = geo_df["geometry"].centroid + geo_df.to_csv(f"{path_save}/geo_predictions.csv") def show_predictions( self, @@ -604,15 +653,15 @@ def explore_predictions( style_kwds=style_kwargs, ) - def load_predictions( + def load_geo_predictions( self, - path_save: str | pathlib.Path, + load_path: str | pathlib.Path, ): """Load georeferenced text predictions from a GeoJSON file. Parameters ---------- - path_save : str | pathlib.Path + load_path : str | pathlib.Path The path to the GeoJSON file. Raises @@ -624,10 +673,10 @@ def load_predictions( ---- This will overwrite any existing predictions! """ - if re.search(r"\..*?json$", str(path_save)): - preds_df = load_from_geojson(path_save, engine="pyogrio") + if re.search(r"\..*?json$", str(load_path)): + preds_df = load_from_geojson(load_path, engine="pyogrio") else: - raise ValueError("[ERROR] ``path_save`` must be a path to a geojson file.") + raise ValueError("[ERROR] ``load_path`` must be a path to a geojson file.") # convert pixel_geometry to shapely geometry preds_df["pixel_geometry"] = preds_df["pixel_geometry"].apply( @@ -648,7 +697,7 @@ def load_predictions( GeoPrediction( pixel_geometry=v.pixel_geometry, score=v.score, - text=v.text, + text=v.text if "text" in v.index else None, patch_id=v.patch_id, geometry=v.geometry, crs=v.crs, @@ -658,7 +707,7 @@ def load_predictions( ParentPrediction( pixel_geometry=v.pixel_geometry, score=v.score, - text=v.text, + text=v.text if "text" in v.index else None, patch_id=v.patch_id, ) ) @@ -689,6 +738,49 @@ def load_predictions( ) ) + def load_patch_predictions( + self, + patch_preds: str | pathlib.Path | pd.DataFrame, + ) -> None: + if not isinstance(patch_preds, pd.DataFrame): + if re.search(r"\..*?csv$", str(patch_preds)): + patch_preds = pd.read_csv(patch_preds, index_col=0) + patch_preds = eval_dataframe(patch_preds) + else: + raise ValueError( + "[ERROR] ``patch_preds`` must be a pandas DataFrame or path to a CSV file." + ) + + # if we have a polygon column, this implies the pixel_geometry column is the centroid + if "polygon" in patch_preds.columns: + patch_preds["pixel_geometry"] = patch_preds["polygon"] + patch_preds.drop(columns=["polygon"], inplace=True) + + # convert pixel_geometry to shapely geometry + patch_preds["pixel_geometry"] = patch_preds["pixel_geometry"].apply( + lambda x: from_wkt(x) + ) + + self.patch_predictions = {} # reset patch predictions + + for image_id in patch_preds["image_id"].unique(): + if image_id not in self.patch_predictions.keys(): + self.patch_predictions[image_id] = [] + + for _, v in patch_preds[patch_preds["image_id"] == image_id].iterrows(): + self.patch_predictions[image_id].append( + PatchPrediction( + pixel_geometry=v.pixel_geometry, + score=v.score, + text=v.text if "text" in v.index else None, + ) + ) + + self.geo_predictions = {} + self.parent_predictions = {} + + self.convert_to_parent_pixel_bounds() + class DetRecRunner(DetRunner): def _get_patch_predictions( @@ -950,14 +1042,14 @@ def explore_search_results( def save_search_results_to_geojson( self, - save_path: str | pathlib.Path, + path_save: str | pathlib.Path, centroid: bool = False, ) -> None: """Convert the search results to georeferenced search results and save them to a GeoJSON file. Parameters ---------- - save_path : str | pathlib.Path + path_save : str | pathlib.Path The path to save the GeoJSON file. centroid : bool, optional Whether to save the centroid of the polygons as the geometry column, by default False. @@ -976,6 +1068,6 @@ def save_search_results_to_geojson( if centroid: geo_df["polygon"] = geo_df["geometry"].to_wkt() - geo_df["geometry"] = geo_df["geometry"].apply(self._polygon_to_centroid) + geo_df["geometry"] = geo_df["geometry"].centroid - geo_df.to_file(save_path, driver="GeoJSON", engine="pyogrio") + geo_df.to_file(path_save, driver="GeoJSON", engine="pyogrio")