-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #99 from GAISSA-UPC/pluginNewFormats
Plugin new formats
- Loading branch information
Showing
10 changed files
with
227 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import os, csv | ||
import tensorflow as tf | ||
from tensorflow.keras.models import load_model | ||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph | ||
import torch | ||
from fvcore.nn import FlopCountAnalysis | ||
|
||
def write_to_csv(output_path, values): | ||
# Ensure the directory exists | ||
os.makedirs(os.path.dirname(output_path), exist_ok=True) | ||
|
||
with open(output_path, 'w', newline='') as file: | ||
writer = csv.writer(file) | ||
|
||
# Write the header | ||
writer.writerow(['model_size', 'file_size', 'flops']) | ||
|
||
# Write the values | ||
writer.writerow(values) | ||
|
||
def calculateH5(model_path, output_path): | ||
model = load_model(model_path) | ||
|
||
#### MODEL SIZE (# parameters) | ||
# Get number of parameters | ||
num_params = model.count_params() | ||
|
||
#### MODEL FILE SIZE | ||
model_size_bytes = os.path.getsize(model_path) | ||
model_size_mb = model_size_bytes / (1024 * 1024) | ||
|
||
#### FLOPS | ||
# convert tf.keras model into frozen graph to count FLOPS about operations used at inference | ||
# FLOPS depends on batch size | ||
inputs = [ | ||
tf.TensorSpec([1] + (list(inp.shape[1:]) if not isinstance(inp.shape[1:], list) else inp.shape[1:]), inp.dtype) | ||
for inp in model.inputs | ||
] | ||
real_model = tf.function(model).get_concrete_function(inputs) | ||
frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model) | ||
|
||
# Calculate FLOPS with tf.profiler | ||
run_meta = tf.compat.v1.RunMetadata() | ||
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation() | ||
flops = tf.compat.v1.profiler.profile( | ||
graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts | ||
) | ||
|
||
# Writing to CSV | ||
write_to_csv(output_path, [num_params, model_size_mb, flops.total_float_ops]) | ||
|
||
def calculateSavedModel(model_path, output_path, input_shape=None): | ||
# Load the model for inference compatibility | ||
model = tf.keras.layers.TFSMLayer(model_path, call_endpoint='serving_default') | ||
|
||
#### MODEL SIZE (# parameters) | ||
num_params = int(sum(tf.reduce_prod(var.shape) for var in model.trainable_variables)) | ||
|
||
#### MODEL FILE SIZE | ||
model_size_bytes = sum(os.path.getsize(os.path.join(dirpath, filename)) | ||
for dirpath, _, filenames in os.walk(model_path) | ||
for filename in filenames) | ||
model_size_mb = model_size_bytes / (1024 * 1024) | ||
|
||
#### FLOPS Calculation | ||
# Use a default input shape if none is provided | ||
if input_shape is None: | ||
print("No input shape provided. Defaulting to image model input shape: (1, 224, 224, 3)") | ||
input_shape = (1, 224, 224, 3) | ||
|
||
# Convert the model to a frozen graph for profiling | ||
concrete_func = tf.function(lambda x: model(x)).get_concrete_function( | ||
tf.TensorSpec(input_shape, model.dtype) | ||
) | ||
frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func) | ||
|
||
# Now that we have the frozen function, we can proceed with profiling | ||
with tf.compat.v1.Session(graph=frozen_func.graph) as sess: | ||
# Run metadata collection for FLOPS and parameter profiling | ||
run_meta = tf.compat.v1.RunMetadata() | ||
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation() | ||
flops = tf.compat.v1.profiler.profile( | ||
graph=sess.graph, | ||
run_meta=run_meta, | ||
cmd="op", | ||
options=opts | ||
) | ||
|
||
# Ensure the profiler has been applied correctly | ||
opts_params = tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter() | ||
params = tf.compat.v1.profiler.profile( | ||
graph=sess.graph, | ||
run_meta=run_meta, | ||
cmd="op", | ||
options=opts_params | ||
) | ||
|
||
#### Write Results to CSV | ||
write_to_csv(output_path, [num_params, model_size_mb, flops.total_float_ops]) | ||
|
||
def calculatePyTorch(model_path, output_path, input_shape=None): | ||
model = torch.load(model_path) | ||
model.eval() | ||
|
||
#### MODEL SIZE (# parameters) | ||
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | ||
|
||
#### MODEL FILE SIZE | ||
model_size_bytes = os.path.getsize(model_path) | ||
model_size_mb = model_size_bytes / (1024 * 1024) | ||
|
||
#### FLOPS for PyTorch model | ||
# Check if the user provided an input shape, otherwise default to (1, 3, 224, 224) for image models | ||
if input_shape is None: | ||
print("No input shape provided. Defaulting to image model input shape: (1, 3, 224, 224)") | ||
input_shape = (1, 3, 224, 224) | ||
|
||
try: | ||
# Create a dummy input tensor with the specified shape | ||
if len(input_shape) == 2: # For language models, we assume input shape like (batch_size, sequence_length) | ||
input_tensor = torch.randint(0, 30522, input_shape, dtype=torch.long) | ||
else: | ||
input_tensor = torch.randn(input_shape) | ||
flops = FlopCountAnalysis(model, input_tensor).total() | ||
except Exception as e: | ||
print(f"Error calculating FLOPS: {e}") | ||
flops = "N/A" | ||
|
||
# Writing to CSV | ||
write_to_csv(output_path, [num_params, model_size_mb, flops]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
from plugin_interface import PluginInterface | ||
from calculator import calculateH5, calculateSavedModel, calculatePyTorch | ||
|
||
class PlugIn(PluginInterface): | ||
def generate_output(self, model_path, output_directory, filename): | ||
# Create the full path for the output file | ||
output_path = os.path.join(output_directory, f"{filename}.csv") | ||
|
||
# Detect model type based on file extension or directory | ||
if model_path.endswith('.h5'): | ||
calculateH5(model_path, output_path) | ||
|
||
elif os.path.isdir(model_path): # SavedModel format is a directory | ||
input_shape = self._get_input_shape("Enter the input shape (e.g.: 1,224,224,3), or press Enter to default: ") | ||
calculateSavedModel(model_path, output_path, input_shape) | ||
|
||
elif model_path.endswith(('.pt', '.pth')): | ||
input_shape = self._get_input_shape("Enter the input shape (e.g.: 1,3,224,224), or press Enter to default: ") | ||
calculatePyTorch(model_path, output_path, input_shape) | ||
|
||
else: | ||
raise ValueError("Unsupported model format. Please provide a .h5, .pt, .pth file or a SavedModel directory.") | ||
|
||
print(f"Output file '{filename}' generated in '{output_directory}'") | ||
|
||
def _get_input_shape(self, prompt): | ||
# Prompts the user to enter an input shape and converts it to a tuple or returns None if no input is provided. | ||
# Converts 'None' strings to actual None values for flexibility. | ||
|
||
input_shape_str = input(prompt) | ||
if not input_shape_str: | ||
return None | ||
return tuple(int(dim) if dim != "None" else None for dim in input_shape_str.split(',')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from gaissaplugin import PlugIn | ||
import os | ||
|
||
def main(): | ||
|
||
# Get user input for model path and output directory and filename | ||
model_path = input("Enter the model's file path: ") | ||
|
||
default_output_directory = "./plugin_output" | ||
output_directory = input(f"Enter the output directory (or continue to default: {default_output_directory}): ") or default_output_directory | ||
|
||
model_name = os.path.splitext(os.path.basename(model_path))[0] | ||
default_filename = f"{model_name}_output" | ||
filename = input(f"Enter the filename (or continue to default: {default_filename}): ") or default_filename | ||
|
||
# Instantiate the plugin | ||
plugin = PlugIn() | ||
|
||
# Call the plugin to generate the output | ||
plugin.generate_output(model_path, output_directory, filename) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# plugin_interface.py | ||
|
||
class PluginInterface: | ||
def generate_output(self, output_directory, filename): | ||
raise NotImplementedError("Subclasses must implement this method") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
tensorflow>=2.0,<3.0 | ||
torch>=2.0.0 | ||
torchvision>=0.15.1 | ||
fvcore>=0.1.5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# GAISSALabel Plug-in | ||
|
||
We offer you a plug-in that runs on the terminal of your server. It will help you generate a file with the configuration parameters of your model. Specifically, it will provide the size of the model, the size of its file, and the FLOPS. | ||
|
||
|
||
## Installation and Usage | ||
|
||
1. Decompress the downloaded file. | ||
2. Access your console and navigate to the decompressed folder. | ||
3. Create a Python virtual environment using the command: | ||
```sh | ||
python -m venv env | ||
``` | ||
4. Activate the virtual environment: | ||
```sh | ||
source env/bin/activate | ||
``` | ||
5. Install the requirements: | ||
```sh | ||
pip install -r requirements.txt | ||
``` | ||
6. Run the script and follow the instructions of the plug-in: | ||
```sh | ||
python main_script.py | ||
``` | ||
|
||
## Updating the GAISSALabel Plug-in | ||
If the GAISSALabel Plug-in is modified, ensure that the new version of the ZIP file is copied to the public folder of the frontend Vue project with the name "GAISSALabel_plugin.zip", so users can access the updated plug-in through the download link. |