From 430473ef071f0b476387067260537f5110b782d4 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 5 Dec 2023 10:21:52 +0000 Subject: [PATCH] add feature-extraction input generator --- optimum_benchmark/backends/utils.py | 20 +++++++++---------- .../generators/task_generator.py | 16 +++++++++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/optimum_benchmark/backends/utils.py b/optimum_benchmark/backends/utils.py index 80ca1b18..d9606b42 100644 --- a/optimum_benchmark/backends/utils.py +++ b/optimum_benchmark/backends/utils.py @@ -51,21 +51,21 @@ def extract_shapes_from_model_artifacts( artifacts_dict.update(processor_dict) # text input - shapes["vocab_size"] = artifacts_dict.get("vocab_size", 2) - if shapes["vocab_size"] == 0: + shapes["vocab_size"] = artifacts_dict.get("vocab_size", None) + if shapes["vocab_size"] is None or shapes["vocab_size"] == 0: shapes["vocab_size"] = 2 - shapes["type_vocab_size"] = artifacts_dict.get("type_vocab_size", 2) - if shapes["type_vocab_size"] == 0: + shapes["type_vocab_size"] = artifacts_dict.get("type_vocab_size", None) + if shapes["type_vocab_size"] is None or shapes["type_vocab_size"] == 0: shapes["type_vocab_size"] = 2 - shapes["max_position_embeddings"] = artifacts_dict.get("max_position_embeddings", 2) - if shapes["max_position_embeddings"] == 0: + shapes["max_position_embeddings"] = artifacts_dict.get("max_position_embeddings", None) + if shapes["max_position_embeddings"] is None or shapes["max_position_embeddings"] == 0: shapes["max_position_embeddings"] = 2 # image input shapes["num_channels"] = artifacts_dict.get("num_channels", None) - if shapes["num_channels"] is None: + if shapes["num_channels"] is None or shapes["num_channels"] == 0: # processors have different names for the number of channels shapes["num_channels"] = artifacts_dict.get("channels", None) @@ -90,14 +90,14 @@ def extract_shapes_from_model_artifacts( shapes["height"] = None shapes["width"] = None - # classification labels (default to 2) + # classification labels id2label = artifacts_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"}) shapes["num_labels"] = len(id2label) if shapes["num_labels"] == 0: shapes["num_labels"] = 2 - # object detection labels (default to 2) - shapes["num_queries"] = artifacts_dict.get("num_queries", 2) + # object detection labels + shapes["num_queries"] = artifacts_dict.get("num_queries", None) if shapes["num_queries"] == 0: shapes["num_queries"] = 2 diff --git a/optimum_benchmark/generators/task_generator.py b/optimum_benchmark/generators/task_generator.py index 2c0f8a68..adf5d61d 100644 --- a/optimum_benchmark/generators/task_generator.py +++ b/optimum_benchmark/generators/task_generator.py @@ -370,8 +370,24 @@ def generate(self): return dummy +class FeatureExtractionGenerator(TextGenerator, ImageGenerator): + def generate(self): + dummy = {} + + if self.shapes["num_channels"] is not None and self.shapes["height"] is not None: + dummy["pixel_values"] = self.pixel_values() + else: + dummy["input_ids"] = self.input_ids() + dummy["attention_mask"] = self.attention_mask() + dummy["token_type_ids"] = self.token_type_ids() + dummy["position_ids"] = self.position_ids() + + return dummy + + TASKS_TO_GENERATORS = { # model tasks + "feature-extraction": FeatureExtractionGenerator, "text-classification": TextClassificationGenerator, "token-classification": TokenClassificationGenerator, "text-generation": TextGenerationGenerator,