Skip to content

Commit

Permalink
add save to /load from csv
Browse files Browse the repository at this point in the history
  • Loading branch information
rwood-97 committed Dec 4, 2024
1 parent 3af480b commit 807e715
Showing 1 changed file with 113 additions and 21 deletions.
134 changes: 113 additions & 21 deletions mapreader/spot_text/runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from tqdm.auto import tqdm

from mapreader import MapImages
from mapreader.utils.load_frames import load_from_csv, load_from_geojson
from mapreader.utils.load_frames import eval_dataframe, load_from_csv, load_from_geojson

from .dataclasses import GeoPrediction, ParentPrediction, PatchPrediction

Expand Down Expand Up @@ -478,18 +478,18 @@ def convert_to_coords(

def save_to_geojson(
self,
save_path: str | pathlib.Path,
path_save: str | pathlib.Path,
centroid: bool = False,
) -> None:
"""Save the georeferenced predictions to a GeoJSON file.
Parameters
----------
save_path : str | pathlib.Path, optional
path_save : str | pathlib.Path, optional
Path to save the GeoJSON file
centroid : bool, optional
Whether to save the centroid of the polygons as the geometry column, by default False.
Note: The original polygon will stil be saved as a separate column.
Whether to convert the polygons to centroids, by default False.
NOTE: The original polygon will still be saved as a separate column
"""
if self.geo_predictions == {}:
raise ValueError(
Expand All @@ -500,12 +500,61 @@ def save_to_geojson(

if centroid:
geo_df["polygon"] = geo_df["geometry"].to_wkt()
geo_df["geometry"] = geo_df["geometry"].apply(self._polygon_to_centroid)
geo_df["geometry"] = geo_df["geometry"].centroid

Check warning on line 503 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L502-L503

Added lines #L502 - L503 were not covered by tests

geo_df.to_file(save_path, driver="GeoJSON", engine="pyogrio")
geo_df.to_file(path_save, driver="GeoJSON", engine="pyogrio")

def _polygon_to_centroid(self, polygon):
return polygon.centroid
def save_to_csv(
self,
path_save: str | pathlib.Path,
centroid: bool = False,
) -> None:
"""Saves the patch, parent and georeferenced predictions to CSV files.
Parameters
----------
path_save : str | pathlib.Path
The path to save the CSV files. Files will be saved as `patch_predictions.csv`, `parent_predictions.csv` and `geo_predictions.csv`.
centroid : bool, optional
Whether to convert polygons to centroids, by default False.
NOTE: The original polygon will still be saved as a separate column.
Note
----
Use the `save_to_geojson` method to save georeferenced predictions to a GeoJSON file.
"""
if self.patch_predictions == {}: # implies no parent or geo predictions
raise ValueError("[ERROR] No patch predictions found.")

Check warning on line 527 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L526-L527

Added lines #L526 - L527 were not covered by tests

if not os.path.exists(path_save):
os.makedirs(path_save)

Check warning on line 530 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L529-L530

Added lines #L529 - L530 were not covered by tests

print("[INFO] Saving patch predictions.")
patch_df = self._dict_to_dataframe(self.patch_predictions)
if centroid:
patch_df["polygon"] = patch_df["pixel_geometry"]
patch_df["pixel_geometry"] = patch_df["pixel_geometry"].apply(

Check warning on line 536 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L532-L536

Added lines #L532 - L536 were not covered by tests
lambda x: x.centroid
)
patch_df.to_csv(f"{path_save}/patch_predictions.csv")

Check warning on line 539 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L539

Added line #L539 was not covered by tests

if self.parent_predictions != {}:
print("[INFO] Saving parent predictions.")
parent_df = self._dict_to_dataframe(self.parent_predictions)
if centroid:
parent_df["polygon"] = parent_df["pixel_geometry"]
parent_df["pixel_geometry"] = parent_df["pixel_geometry"].apply(

Check warning on line 546 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L541-L546

Added lines #L541 - L546 were not covered by tests
lambda x: x.centroid
)
parent_df.to_csv(f"{path_save}/parent_predictions.csv")

Check warning on line 549 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L549

Added line #L549 was not covered by tests

if self.geo_predictions != {}:
print("[INFO] Saving geo predictions.")
geo_df = self._dict_to_dataframe(self.geo_predictions)
if centroid:
geo_df["polygon"] = geo_df["geometry"]
geo_df["geometry"] = geo_df["geometry"].centroid
geo_df.to_csv(f"{path_save}/geo_predictions.csv")

Check warning on line 557 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L551-L557

Added lines #L551 - L557 were not covered by tests

def show_predictions(
self,
Expand Down Expand Up @@ -604,15 +653,15 @@ def explore_predictions(
style_kwds=style_kwargs,
)

def load_predictions(
def load_geo_predictions(
self,
path_save: str | pathlib.Path,
load_path: str | pathlib.Path,
):
"""Load georeferenced text predictions from a GeoJSON file.
Parameters
----------
path_save : str | pathlib.Path
load_path : str | pathlib.Path
The path to the GeoJSON file.
Raises
Expand All @@ -624,10 +673,10 @@ def load_predictions(
----
This will overwrite any existing predictions!
"""
if re.search(r"\..*?json$", str(path_save)):
preds_df = load_from_geojson(path_save, engine="pyogrio")
if re.search(r"\..*?json$", str(load_path)):
preds_df = load_from_geojson(load_path, engine="pyogrio")

Check warning on line 677 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L676-L677

Added lines #L676 - L677 were not covered by tests
else:
raise ValueError("[ERROR] ``path_save`` must be a path to a geojson file.")
raise ValueError("[ERROR] ``load_path`` must be a path to a geojson file.")

Check warning on line 679 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L679

Added line #L679 was not covered by tests

# convert pixel_geometry to shapely geometry
preds_df["pixel_geometry"] = preds_df["pixel_geometry"].apply(

Check warning on line 682 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L682

Added line #L682 was not covered by tests
Expand All @@ -648,7 +697,7 @@ def load_predictions(
GeoPrediction(
pixel_geometry=v.pixel_geometry,
score=v.score,
text=v.text,
text=v.text if "text" in v.index else None,
patch_id=v.patch_id,
geometry=v.geometry,
crs=v.crs,
Expand All @@ -658,7 +707,7 @@ def load_predictions(
ParentPrediction(
pixel_geometry=v.pixel_geometry,
score=v.score,
text=v.text,
text=v.text if "text" in v.index else None,
patch_id=v.patch_id,
)
)
Expand Down Expand Up @@ -689,6 +738,49 @@ def load_predictions(
)
)

def load_patch_predictions(
self,
patch_preds: str | pathlib.Path | pd.DataFrame,
) -> None:
if not isinstance(patch_preds, pd.DataFrame):
if re.search(r"\..*?csv$", str(patch_preds)):
patch_preds = pd.read_csv(patch_preds, index_col=0)
patch_preds = eval_dataframe(patch_preds)

Check warning on line 748 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L745-L748

Added lines #L745 - L748 were not covered by tests
else:
raise ValueError(

Check warning on line 750 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L750

Added line #L750 was not covered by tests
"[ERROR] ``patch_preds`` must be a pandas DataFrame or path to a CSV file."
)

# if we have a polygon column, this implies the pixel_geometry column is the centroid
if "polygon" in patch_preds.columns:
patch_preds["pixel_geometry"] = patch_preds["polygon"]
patch_preds.drop(columns=["polygon"], inplace=True)

Check warning on line 757 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L755-L757

Added lines #L755 - L757 were not covered by tests

# convert pixel_geometry to shapely geometry
patch_preds["pixel_geometry"] = patch_preds["pixel_geometry"].apply(

Check warning on line 760 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L760

Added line #L760 was not covered by tests
lambda x: from_wkt(x)
)

self.patch_predictions = {} # reset patch predictions

Check warning on line 764 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L764

Added line #L764 was not covered by tests

for image_id in patch_preds["image_id"].unique():
if image_id not in self.patch_predictions.keys():
self.patch_predictions[image_id] = []

Check warning on line 768 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L766-L768

Added lines #L766 - L768 were not covered by tests

for _, v in patch_preds[patch_preds["image_id"] == image_id].iterrows():
self.patch_predictions[image_id].append(

Check warning on line 771 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L770-L771

Added lines #L770 - L771 were not covered by tests
PatchPrediction(
pixel_geometry=v.pixel_geometry,
score=v.score,
text=v.text if "text" in v.index else None,
)
)

self.geo_predictions = {}
self.parent_predictions = {}

Check warning on line 780 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L779-L780

Added lines #L779 - L780 were not covered by tests

self.convert_to_parent_pixel_bounds()

Check warning on line 782 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L782

Added line #L782 was not covered by tests


class DetRecRunner(DetRunner):
def _get_patch_predictions(
Expand Down Expand Up @@ -950,14 +1042,14 @@ def explore_search_results(

def save_search_results_to_geojson(
self,
save_path: str | pathlib.Path,
path_save: str | pathlib.Path,
centroid: bool = False,
) -> None:
"""Convert the search results to georeferenced search results and save them to a GeoJSON file.
Parameters
----------
save_path : str | pathlib.Path
path_save : str | pathlib.Path
The path to save the GeoJSON file.
centroid : bool, optional
Whether to save the centroid of the polygons as the geometry column, by default False.
Expand All @@ -976,6 +1068,6 @@ def save_search_results_to_geojson(

if centroid:
geo_df["polygon"] = geo_df["geometry"].to_wkt()
geo_df["geometry"] = geo_df["geometry"].apply(self._polygon_to_centroid)
geo_df["geometry"] = geo_df["geometry"].centroid

Check warning on line 1071 in mapreader/spot_text/runner_base.py

View check run for this annotation

Codecov / codecov/patch

mapreader/spot_text/runner_base.py#L1070-L1071

Added lines #L1070 - L1071 were not covered by tests

geo_df.to_file(save_path, driver="GeoJSON", engine="pyogrio")
geo_df.to_file(path_save, driver="GeoJSON", engine="pyogrio")

0 comments on commit 807e715

Please sign in to comment.