Skip to content

Commit

Permalink
Merge branch 'add_model_scores' into issue/283
Browse files Browse the repository at this point in the history
  • Loading branch information
2320sharon committed Jan 21, 2025
2 parents 6d1b65f + cd9efdd commit 11df944
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
1 change: 1 addition & 0 deletions 5_zoo_workflow_local_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"tta": False, # Test Time Augmentation
"use_local_model": True, # Use local model (not one from zeneodo)
"local_model_path": r"C:\development\doodleverse\coastseg\CoastSeg\src\coastseg\downloaded_models\non_validation_model", # path to the local model
"apply_segmentation_filter": True, # apply segmentation filter to the model outputs to sort them into good or bad
}

# Available models can run input "RGB" # or "MNDWI" or "NDWI"
Expand Down
1 change: 1 addition & 0 deletions 6_zoo_workflow_with_coregistration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"model_type": "global_segformer_RGB_4class_14036903", # model name from the zoo
"otsu": False, # Otsu Thresholding
"tta": False, # Test Time Augmentation
"apply_segmentation_filter": True, # apply segmentation filter to the model outputs to sort them into good or bad
}
# Available models can run input "RGB" # or "MNDWI" or "NDWI"
img_type = "RGB" # make sure the model name is compatible with the image type
Expand Down
25 changes: 14 additions & 11 deletions src/coastseg/classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import glob
import pandas as pd
import numpy as np
import shutil
import pooch
import tensorflow as tf
Expand All @@ -12,6 +13,7 @@
# Some of these functions were originally written by Mark Lundine and have been modified for this project.



def filter_segmentations(
session_path: str,
threshold: float = 0.40,
Expand Down Expand Up @@ -173,13 +175,13 @@ def run_inference_rgb_image_classifier(path_to_model_ckpt,
path_to_model_ckpt (str): path to the saved keras model
path_to_inference_imgs (str): path to the folder containing images to run the model on
output_folder (str): path to save outputs to
csv_path (str): csv path to save results to. If not provided, the results will be saved to output_folder/classification_results.csv
csv_path (str): csv path to save results to. If not provided, the results will be saved to output_folder/image_classification_results.csv
threshold (float): threshold on sigmoid of model output (ex: 0.6 means mark images as good if model output is >= 0.6, or 60% sure it's a good image)
returns:
csv_path (str): csv path of saved results
"""
if not csv_path:
csv_path = os.path.join(output_folder, 'classification_results.csv')
csv_path = os.path.join(output_folder, 'image_classification_results.csv')

os.makedirs(output_folder,exist_ok=True)

Expand All @@ -194,7 +196,6 @@ def run_inference_rgb_image_classifier(path_to_model_ckpt,
im_classes = [None]*len(im_paths)
i=0
for im_path in im_paths:
print(im_path)
img = keras.utils.load_img(im_path, color_mode='rgb',target_size=image_size)
img_array = keras.utils.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)
Expand All @@ -204,7 +205,8 @@ def run_inference_rgb_image_classifier(path_to_model_ckpt,
i=i+1
##save results to a csv
df = pd.DataFrame({'im_paths':im_paths,
'model_scores':model_scores
'model_scores':model_scores,
'threshold':np.full(len(im_paths), threshold)
}
)

Expand All @@ -229,13 +231,13 @@ def run_inference_gray_image_classifier(path_to_model_ckpt,
path_to_model_ckpt (str): path to the saved keras model
path_to_inference_imgs (str): path to the folder containing images to run the model on
output_folder (str): path to save outputs to
csv_path (str): csv path to save results to. If not provided, the results will be saved to output_folder/classification_results.csv
csv_path (str): csv path to save results to. If not provided, the results will be saved to output_folder/image_classification_results.csv
threshold (float): threshold on sigmoid of model output (ex: 0.6 means mark images as good if model output is >= 0.6, or 60% sure it's a good image)
returns:
csv_path (str): csv path of saved results
"""
if not csv_path:
csv_path = os.path.join(output_folder, 'classification_results.csv')
csv_path = os.path.join(output_folder, 'image_classification_results.csv')

os.makedirs(output_folder,exist_ok=True)
image_size = (128, 128)
Expand All @@ -258,7 +260,8 @@ def run_inference_gray_image_classifier(path_to_model_ckpt,
i=i+1
##save results to a csv
df = pd.DataFrame({'im_paths':im_paths,
'model_scores':model_scores
'model_scores':model_scores,
'threshold':np.full(len(im_paths), threshold)
}
)
df.to_csv(csv_path)
Expand Down Expand Up @@ -349,7 +352,6 @@ def get_image_classifier(type:str='rgb') -> str:
)
else: # get the grayscale model
model_name ='ImageGrayClassifier'
print(model_name)
model_directory = file_utilities.create_directory(
downloaded_models_path, model_name
)
Expand Down Expand Up @@ -429,7 +431,7 @@ def run_inference_segmentation_classifier(path_to_model_ckpt:str,
path_to_inference_imgs (str): path to the folder containing images to run the model on
output_folder (str): path to save outputs to
csv_path (str): csv path to save results to
If not provided, the results will be saved to output_folder/classification_results.csv
If not provided, the results will be saved to output_folder/image_classification_results.csv
threshold (float): threshold on sigmoid of model output (ex: 0.6 means mark images as good if model output is >= 0.6, or 60% sure it's a good image)
returns:
Expand Down Expand Up @@ -478,12 +480,13 @@ def run_inference_segmentation_classifier(path_to_model_ckpt:str,
i=i+1
##save results to a csv
df = pd.DataFrame({'im_paths':im_paths,
'model_scores':model_scores
'model_scores':model_scores,
'threshold':np.full(len(im_paths), threshold)
}
)

if not csv_path:
csv_path = os.path.join(output_folder, 'classification_results.csv')
csv_path = os.path.join(output_folder, 'segmentation_classification_results.csv')

df.to_csv(csv_path)
good_path,bad_path=sort_images(csv_path,
Expand Down

0 comments on commit 11df944

Please sign in to comment.