From c39b68a67c442f34395fbab52cf165fc2229c936 Mon Sep 17 00:00:00 2001
From: Hiten Vidhani <60215053+hitenvidhani@users.noreply.github.com>
Date: Fri, 15 Sep 2023 00:07:20 +0530
Subject: [PATCH] Dataset size on CLI (#345)

* init

* lint

* fix

* lint

* lint

* rm print stmt

* Update dataset_utils.py
---
 .../description_dataset_retriever.py     |  9 ++-
 prompt2model/utils/dataset_utils.py      | 33 ++++++++++
 prompt2model/utils/dataset_utils_test.py | 64 +++++++++++++++++++
 3 files changed, 103 insertions(+), 3 deletions(-)
 create mode 100644 prompt2model/utils/dataset_utils.py
 create mode 100644 prompt2model/utils/dataset_utils_test.py

diff --git a/prompt2model/dataset_retriever/description_dataset_retriever.py b/prompt2model/dataset_retriever/description_dataset_retriever.py
index e3745669c..0f0b4e828 100644
--- a/prompt2model/dataset_retriever/description_dataset_retriever.py
+++ b/prompt2model/dataset_retriever/description_dataset_retriever.py
@@ -13,6 +13,7 @@
 from prompt2model.dataset_retriever.base import DatasetInfo, DatasetRetriever
 from prompt2model.prompt_parser import PromptSpec
 from prompt2model.utils import encode_text, retrieve_objects
+from prompt2model.utils.dataset_utils import get_dataset_size
 
 datasets.utils.logging.disable_progress_bar()
 logger = logging.getLogger(__name__)
@@ -115,10 +116,12 @@ def choose_dataset_by_cli(self, top_datasets: list[DatasetInfo]) -> str | None:
         """
         self._print_divider()
         print("Here are the datasets I've retrieved for you:")
-        print("#\tName\tDescription")
+        print("#\tName\tSize[MB]\tDescription")
         for i, d in enumerate(top_datasets):
-            description_no_spaces = d.description.replace("\n", " ")
-            print(f"{i+1}):\t{d.name}\t{description_no_spaces}")
+            description_no_space = d.description.replace("\n", " ")
+            print(
+                f"{i+1}):\t{d.name}\t{get_dataset_size(d.name)}\t{description_no_space}"
+            )
         self._print_divider()
 
         print(
diff --git a/prompt2model/utils/dataset_utils.py b/prompt2model/utils/dataset_utils.py
new file mode 100644
index 000000000..2f89257c5
--- /dev/null
+++ b/prompt2model/utils/dataset_utils.py
@@ -0,0 +1,33 @@
+"""Util functions for datasets."""
+
+import requests
+
+from prompt2model.utils.logging_utils import get_formatted_logger
+
+logger = get_formatted_logger("dataset_utils")
+
+
+def query(API_URL):
+    """Returns the response JSON for a URL, or an empty dict on failure."""
+    try:
+        response = requests.get(API_URL)
+        if response.status_code == 200:
+            return response.json()
+        else:
+            logger.error(f"Error occurred in fetching size: {response.status_code}")
+    except requests.exceptions.RequestException as e:
+        logger.error("Error occurred in making the request: " + str(e))
+
+    return {}
+
+
+def get_dataset_size(dataset_name):
+    """Fetches the size of a dataset in MB from the Hugging Face API."""
+    API_URL = f"https://datasets-server.huggingface.co/size?dataset={dataset_name}"
+    data = query(API_URL)
+    size_dict = data.get("size", {})
+    return (
+        "NA"
+        if not size_dict
+        else "{:.2f}".format(size_dict["dataset"]["num_bytes_memory"] / 1024 / 1024)
+    )
diff --git a/prompt2model/utils/dataset_utils_test.py b/prompt2model/utils/dataset_utils_test.py
new file mode 100644
index 000000000..9df62ba39
--- /dev/null
+++ b/prompt2model/utils/dataset_utils_test.py
@@ -0,0 +1,64 @@
+"""Testing dataset utility functions."""
+from unittest.mock import patch
+
+from prompt2model.utils import dataset_utils
+
+
+@patch("prompt2model.utils.dataset_utils.query")
+def test_get_dataset_size(mock_request):
+    """Test function for get_dataset_size."""
+    mock_request.return_value = {
+        "size": {
+            "dataset": {
+                "dataset": "rotten_tomatoes",
+                "num_bytes_original_files": 487770,
+                "num_bytes_parquet_files": 881052,
+                "num_bytes_memory": 1345449,
+                "num_rows": 10662,
+            },
+            "configs": [
+                {
+                    "dataset": "rotten_tomatoes",
+                    "config": "default",
+                    "num_bytes_original_files": 487770,
+                    "num_bytes_parquet_files": 881052,
+                    "num_bytes_memory": 1345449,
+                    "num_rows": 10662,
+                    "num_columns": 2,
+                }
+            ],
+            "splits": [
+                {
+                    "dataset": "rotten_tomatoes",
+                    "config": "default",
+                    "split": "train",
+                    "num_bytes_parquet_files": 698845,
+                    "num_bytes_memory": 1074806,
+                    "num_rows": 8530,
+                    "num_columns": 2,
+                },
+                {
+                    "dataset": "rotten_tomatoes",
+                    "config": "default",
+                    "split": "validation",
+                    "num_bytes_parquet_files": 90001,
+                    "num_bytes_memory": 134675,
+                    "num_rows": 1066,
+                    "num_columns": 2,
+                },
+                {
+                    "dataset": "rotten_tomatoes",
+                    "config": "default",
+                    "split": "test",
+                    "num_bytes_parquet_files": 92206,
+                    "num_bytes_memory": 135968,
+                    "num_rows": 1066,
+                    "num_columns": 2,
+                },
+            ],
+        },
+        "pending": [],
+        "failed": [],
+        "partial": False,
+    }
+    assert dataset_utils.get_dataset_size("rotten_tomatoes") == "1.28"
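
For reviewers who want a quick manual check of the new helper, a minimal sketch is below. It is not part of the patch: the script name is hypothetical, "rotten_tomatoes" is only an example dataset, and live access to the Hugging Face datasets-server endpoint used by get_dataset_size is assumed.

    # quick_size_check.py -- hypothetical manual check, not included in this patch
    from prompt2model.utils.dataset_utils import get_dataset_size

    # Prints the dataset's in-memory size in MB formatted to two decimals
    # (e.g. "1.28" for rotten_tomatoes), or "NA" if the size API call fails.
    print(get_dataset_size("rotten_tomatoes"))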