Skip to content

Commit

Permalink
Merge pull request #99 from GAISSA-UPC/pluginNewFormats
Browse files Browse the repository at this point in the history
Plugin new formats
  • Loading branch information
pol-33 authored Nov 5, 2024
2 parents 6f4a2d7 + 6eb7c51 commit ac678f9
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 3 deletions.
4 changes: 2 additions & 2 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ absl-py==2.0.0
arrow==1.3.0
asgiref==3.7.2
astunparse==1.6.3
backports.zoneinfo==0.2.1
backports.zoneinfo==0.2.1;python_version<"3.9"
cachetools==5.3.2
certifi==2023.5.7
charset-normalizer==3.2.0
Expand Down Expand Up @@ -86,7 +86,7 @@ six==1.16.0
sqlparse==0.4.4
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorflow==2.13.1
tensorflow==2.13.0
tensorflow-estimator==2.13.0
tensorflow-io-gcs-filesystem==0.34.0
termcolor==2.3.0
Expand Down
Binary file added frontend/public/GAISSALabel_plugin.zip
Binary file not shown.
2 changes: 1 addition & 1 deletion frontend/src/components/FileNewExperiment.vue
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
<p style="font-weight: bold">{{ $t('GAISSALabel plug-in') }}</p>
<p>{{ $t('We offer you a plug-in that runs on the terminal of your server. It will help you generate a file with the \
configuration parameters of your model. Specifically, it will provide the size of the model, the size of its file and the FLOPS.') }}</p>
<p>{{ $t('You can download it through this') }} <a href="https://drive.google.com/file/d/1oeYVYamzdtR6t4SV8U8Rd-9g-jm-u_ZU/view?usp=sharing" target="_blank">{{ $t('link') }}</a>.</p>
<p>{{ $t('You can download it through this') }} <a href="/GAISSALabel_plugin.zip" target="_blank">{{ $t('link') }}</a>.</p>
<p>{{ $t('Then, follow these steps:') }}</p>
<ol style="margin-left: 30px">
<li><p>{{ $t('Decompress the downloaded file') }}</p></li>
Expand Down
Binary file added plugin/GAISSALabel_plugin.zip
Binary file not shown.
130 changes: 130 additions & 0 deletions plugin/GAISSALabel_plugin/calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os, csv
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
import torch
from fvcore.nn import FlopCountAnalysis

def write_to_csv(output_path, values):
# Ensure the directory exists
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, 'w', newline='') as file:
writer = csv.writer(file)

# Write the header
writer.writerow(['model_size', 'file_size', 'flops'])

# Write the values
writer.writerow(values)

def calculateH5(model_path, output_path):
model = load_model(model_path)

#### MODEL SIZE (# parameters)
# Get number of parameters
num_params = model.count_params()

#### MODEL FILE SIZE
model_size_bytes = os.path.getsize(model_path)
model_size_mb = model_size_bytes / (1024 * 1024)

#### FLOPS
# convert tf.keras model into frozen graph to count FLOPS about operations used at inference
# FLOPS depends on batch size
inputs = [
tf.TensorSpec([1] + (list(inp.shape[1:]) if not isinstance(inp.shape[1:], list) else inp.shape[1:]), inp.dtype)
for inp in model.inputs
]
real_model = tf.function(model).get_concrete_function(inputs)
frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)

# Calculate FLOPS with tf.profiler
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
flops = tf.compat.v1.profiler.profile(
graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
)

# Writing to CSV
write_to_csv(output_path, [num_params, model_size_mb, flops.total_float_ops])

def calculateSavedModel(model_path, output_path, input_shape=None):
# Load the model for inference compatibility
model = tf.keras.layers.TFSMLayer(model_path, call_endpoint='serving_default')

#### MODEL SIZE (# parameters)
num_params = int(sum(tf.reduce_prod(var.shape) for var in model.trainable_variables))

#### MODEL FILE SIZE
model_size_bytes = sum(os.path.getsize(os.path.join(dirpath, filename))
for dirpath, _, filenames in os.walk(model_path)
for filename in filenames)
model_size_mb = model_size_bytes / (1024 * 1024)

#### FLOPS Calculation
# Use a default input shape if none is provided
if input_shape is None:
print("No input shape provided. Defaulting to image model input shape: (1, 224, 224, 3)")
input_shape = (1, 224, 224, 3)

# Convert the model to a frozen graph for profiling
concrete_func = tf.function(lambda x: model(x)).get_concrete_function(
tf.TensorSpec(input_shape, model.dtype)
)
frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)

# Now that we have the frozen function, we can proceed with profiling
with tf.compat.v1.Session(graph=frozen_func.graph) as sess:
# Run metadata collection for FLOPS and parameter profiling
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
flops = tf.compat.v1.profiler.profile(
graph=sess.graph,
run_meta=run_meta,
cmd="op",
options=opts
)

# Ensure the profiler has been applied correctly
opts_params = tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter()
params = tf.compat.v1.profiler.profile(
graph=sess.graph,
run_meta=run_meta,
cmd="op",
options=opts_params
)

#### Write Results to CSV
write_to_csv(output_path, [num_params, model_size_mb, flops.total_float_ops])

def calculatePyTorch(model_path, output_path, input_shape=None):
model = torch.load(model_path)
model.eval()

#### MODEL SIZE (# parameters)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

#### MODEL FILE SIZE
model_size_bytes = os.path.getsize(model_path)
model_size_mb = model_size_bytes / (1024 * 1024)

#### FLOPS for PyTorch model
# Check if the user provided an input shape, otherwise default to (1, 3, 224, 224) for image models
if input_shape is None:
print("No input shape provided. Defaulting to image model input shape: (1, 3, 224, 224)")
input_shape = (1, 3, 224, 224)

try:
# Create a dummy input tensor with the specified shape
if len(input_shape) == 2: # For language models, we assume input shape like (batch_size, sequence_length)
input_tensor = torch.randint(0, 30522, input_shape, dtype=torch.long)
else:
input_tensor = torch.randn(input_shape)
flops = FlopCountAnalysis(model, input_tensor).total()
except Exception as e:
print(f"Error calculating FLOPS: {e}")
flops = "N/A"

# Writing to CSV
write_to_csv(output_path, [num_params, model_size_mb, flops])
34 changes: 34 additions & 0 deletions plugin/GAISSALabel_plugin/gaissaplugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from plugin_interface import PluginInterface
from calculator import calculateH5, calculateSavedModel, calculatePyTorch

class PlugIn(PluginInterface):
def generate_output(self, model_path, output_directory, filename):
# Create the full path for the output file
output_path = os.path.join(output_directory, f"{filename}.csv")

# Detect model type based on file extension or directory
if model_path.endswith('.h5'):
calculateH5(model_path, output_path)

elif os.path.isdir(model_path): # SavedModel format is a directory
input_shape = self._get_input_shape("Enter the input shape (e.g.: 1,224,224,3), or press Enter to default: ")
calculateSavedModel(model_path, output_path, input_shape)

elif model_path.endswith(('.pt', '.pth')):
input_shape = self._get_input_shape("Enter the input shape (e.g.: 1,3,224,224), or press Enter to default: ")
calculatePyTorch(model_path, output_path, input_shape)

else:
raise ValueError("Unsupported model format. Please provide a .h5, .pt, .pth file or a SavedModel directory.")

print(f"Output file '{filename}' generated in '{output_directory}'")

def _get_input_shape(self, prompt):
# Prompts the user to enter an input shape and converts it to a tuple or returns None if no input is provided.
# Converts 'None' strings to actual None values for flexibility.

input_shape_str = input(prompt)
if not input_shape_str:
return None
return tuple(int(dim) if dim != "None" else None for dim in input_shape_str.split(','))
23 changes: 23 additions & 0 deletions plugin/GAISSALabel_plugin/main_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from gaissaplugin import PlugIn
import os

def main():

# Get user input for model path and output directory and filename
model_path = input("Enter the model's file path: ")

default_output_directory = "./plugin_output"
output_directory = input(f"Enter the output directory (or continue to default: {default_output_directory}): ") or default_output_directory

model_name = os.path.splitext(os.path.basename(model_path))[0]
default_filename = f"{model_name}_output"
filename = input(f"Enter the filename (or continue to default: {default_filename}): ") or default_filename

# Instantiate the plugin
plugin = PlugIn()

# Call the plugin to generate the output
plugin.generate_output(model_path, output_directory, filename)

if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions plugin/GAISSALabel_plugin/plugin_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# plugin_interface.py

class PluginInterface:
def generate_output(self, output_directory, filename):
raise NotImplementedError("Subclasses must implement this method")
4 changes: 4 additions & 0 deletions plugin/GAISSALabel_plugin/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
tensorflow>=2.0,<3.0
torch>=2.0.0
torchvision>=0.15.1
fvcore>=0.1.5
28 changes: 28 additions & 0 deletions plugin/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# GAISSALabel Plug-in

We offer you a plug-in that runs on the terminal of your server. It will help you generate a file with the configuration parameters of your model. Specifically, it will provide the size of the model, the size of its file, and the FLOPS.


## Installation and Usage

1. Decompress the downloaded file.
2. Access your console and navigate to the decompressed folder.
3. Create a Python virtual environment using the command:
```sh
python -m venv env
```
4. Activate the virtual environment:
```sh
source env/bin/activate
```
5. Install the requirements:
```sh
pip install -r requirements.txt
```
6. Run the script and follow the instructions of the plug-in:
```sh
python main_script.py
```

## Updating the GAISSALabel Plug-in
If the GAISSALabel Plug-in is modified, ensure that the new version of the ZIP file is copied to the public folder of the frontend Vue project with the name "GAISSALabel_plugin.zip", so users can access the updated plug-in through the download link.

0 comments on commit ac678f9

Please sign in to comment.