Skip to content

Commit

Permalink
add feature-extraction input generator
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 5, 2023
1 parent 078debf commit 430473e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
20 changes: 10 additions & 10 deletions optimum_benchmark/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,21 @@ def extract_shapes_from_model_artifacts(
artifacts_dict.update(processor_dict)

# text input
shapes["vocab_size"] = artifacts_dict.get("vocab_size", 2)
if shapes["vocab_size"] == 0:
shapes["vocab_size"] = artifacts_dict.get("vocab_size", None)
if shapes["vocab_size"] is None or shapes["vocab_size"] == 0:
shapes["vocab_size"] = 2

shapes["type_vocab_size"] = artifacts_dict.get("type_vocab_size", 2)
if shapes["type_vocab_size"] == 0:
shapes["type_vocab_size"] = artifacts_dict.get("type_vocab_size", None)
if shapes["type_vocab_size"] is None or shapes["type_vocab_size"] == 0:
shapes["type_vocab_size"] = 2

shapes["max_position_embeddings"] = artifacts_dict.get("max_position_embeddings", 2)
if shapes["max_position_embeddings"] == 0:
shapes["max_position_embeddings"] = artifacts_dict.get("max_position_embeddings", None)
if shapes["max_position_embeddings"] is None or shapes["max_position_embeddings"] == 0:
shapes["max_position_embeddings"] = 2

# image input
shapes["num_channels"] = artifacts_dict.get("num_channels", None)
if shapes["num_channels"] is None:
if shapes["num_channels"] is None or shapes["num_channels"] == 0:
# processors have different names for the number of channels
shapes["num_channels"] = artifacts_dict.get("channels", None)

Expand All @@ -90,14 +90,14 @@ def extract_shapes_from_model_artifacts(
shapes["height"] = None
shapes["width"] = None

# classification labels (default to 2)
# classification labels
id2label = artifacts_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"})
shapes["num_labels"] = len(id2label)
if shapes["num_labels"] == 0:
shapes["num_labels"] = 2

# object detection labels (default to 2)
shapes["num_queries"] = artifacts_dict.get("num_queries", 2)
# object detection labels
shapes["num_queries"] = artifacts_dict.get("num_queries", None)
if shapes["num_queries"] == 0:
shapes["num_queries"] = 2

Expand Down
16 changes: 16 additions & 0 deletions optimum_benchmark/generators/task_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,24 @@ def generate(self):
return dummy


class FeatureExtractionGenerator(TextGenerator, ImageGenerator):
    """Dummy-input generator registered for the "feature-extraction" task.

    Emits image inputs when the shapes carry both a channel count and a
    height; otherwise falls back to the text inputs.
    """

    def generate(self):
        """Build and return the dict of dummy inputs, keyed by input name."""
        shapes = self.shapes

        # Without both a channel count and a height, use the text inputs.
        if shapes["num_channels"] is None or shapes["height"] is None:
            return {
                "input_ids": self.input_ids(),
                "attention_mask": self.attention_mask(),
                "token_type_ids": self.token_type_ids(),
                "position_ids": self.position_ids(),
            }

        return {"pixel_values": self.pixel_values()}


TASKS_TO_GENERATORS = {
# model tasks
"feature-extraction": FeatureExtractionGenerator,
"text-classification": TextClassificationGenerator,
"token-classification": TokenClassificationGenerator,
"text-generation": TextGenerationGenerator,
Expand Down

0 comments on commit 430473e

Please sign in to comment.