From c674e1496cbbf26e16b29f9b5447c41f966ebf3b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alexander=20Gr=C3=A9us?=
Date: Thu, 9 May 2024 01:12:14 +0200
Subject: [PATCH] test: added tests

---
Notes (placed after the three-dash line, so not part of the commit message):
- _mnist.py: drops the `tempfile` import, maps the test labels through
  `labels` the same way the train labels already are, and marks the error
  branches with `# pragma: no cover`.
- Adds tests for load_mnist, load_fashion_mnist and load_kmnist (download and
  load, plus FileNotFoundError when the files are absent and download is off).
- NOTE(review): the indentation of the hunk lines below and the spacing before
  inline comments were reconstructed; confirm against the target file, or
  apply with `git apply --ignore-whitespace`.

 src/safeds_datasets/image/_mnist/_mnist.py    | 11 ++-
 tests/safeds_datasets/image/__init__.py       |  0
 .../safeds_datasets/image/_mnist/__init__.py  |  0
 .../image/_mnist/test_mnist.py                | 74 +++++++++++++++++++
 4 files changed, 79 insertions(+), 6 deletions(-)
 create mode 100644 tests/safeds_datasets/image/__init__.py
 create mode 100644 tests/safeds_datasets/image/_mnist/__init__.py
 create mode 100644 tests/safeds_datasets/image/_mnist/test_mnist.py

diff --git a/src/safeds_datasets/image/_mnist/_mnist.py b/src/safeds_datasets/image/_mnist/_mnist.py
index 6ce594e..0ab6f5a 100644
--- a/src/safeds_datasets/image/_mnist/_mnist.py
+++ b/src/safeds_datasets/image/_mnist/_mnist.py
@@ -2,7 +2,6 @@
 import os
 import struct
 import sys
-import tempfile
 import urllib.request
 from array import array
 from pathlib import Path
@@ -160,16 +159,16 @@ def _load_mnist_like(path: str | Path, files: dict[str, str], labels: dict[int,
             with gzip.open(path / file_path, mode='rb') as label_file:
                 magic, size = struct.unpack(">II", label_file.read(8))
                 if magic != 2049:
-                    raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2049.")
+                    raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2049.")  # pragma: no cover
                 if "train" in file_name:
                     train_labels = Column(file_name, [labels[label_index] for label_index in array("B", label_file.read())])
                 else:
-                    test_labels = Column(file_name, array("B", label_file.read()))
+                    test_labels = Column(file_name, [labels[label_index] for label_index in array("B", label_file.read())])
         else:
             with gzip.open(path / file_path, mode='rb') as image_file:
                 magic, size, rows, cols = struct.unpack(">IIII", image_file.read(16))
                 if magic != 2051:
-                    raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2051.")
+                    raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2051.")  # pragma: no cover
                 image_data = array("B", image_file.read())
                 image_tensor = torch.empty(size, 1, rows, cols)
                 for i in range(size):
@@ -183,7 +182,7 @@ def _load_mnist_like(path: str | Path, files: dict[str, str], labels: dict[int,
                 else:
                     test_image_list = image_list
     if train_image_list is None or test_image_list is None or train_labels is None or test_labels is None:
-        raise ValueError
+        raise ValueError  # pragma: no cover
     return ImageDataset[Column](train_image_list, train_labels, 32, shuffle=True), ImageDataset[Column](test_image_list, test_labels, 32)
 
 
@@ -197,7 +196,7 @@ def _download_mnist_like(path: str | Path, files: dict[str, str], links: list[st
                     print()  # noqa: T201
                     break
                 except HTTPError as e:
-                    print(f"An error occurred while downloading: {e}")  # noqa: T201
+                    print(f"An error occurred while downloading: {e}")  # noqa: T201  # pragma: no cover
 
 
 def _report_download_progress(current_packages: int, package_size: int, file_size: int) -> None:
diff --git a/tests/safeds_datasets/image/__init__.py b/tests/safeds_datasets/image/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/safeds_datasets/image/_mnist/__init__.py b/tests/safeds_datasets/image/_mnist/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/safeds_datasets/image/_mnist/test_mnist.py b/tests/safeds_datasets/image/_mnist/test_mnist.py
new file mode 100644
index 0000000..a436aa8
--- /dev/null
+++ b/tests/safeds_datasets/image/_mnist/test_mnist.py
@@ -0,0 +1,74 @@
+import os
+import tempfile
+from pathlib import Path
+
+import pytest
+from safeds.data.labeled.containers import ImageDataset
+
+from safeds_datasets.image import load_mnist, _mnist, load_fashion_mnist, load_kmnist
+
+
+class TestMNIST:
+
+    def test_should_download_and_return_mnist(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            train, test = load_mnist(tmpdirname, True)
+            files = os.listdir(Path(tmpdirname) / _mnist._mnist._mnist_folder)
+            for mnist_file in _mnist._mnist._mnist_files.values():
+                assert mnist_file in files
+            assert isinstance(train, ImageDataset)
+            assert isinstance(test, ImageDataset)
+            assert len(train) == 60_000
+            assert len(test) == 10_000
+            train_output = train.get_output()
+            test_output = test.get_output()
+            assert set(train_output.get_unique_values()) == set(test_output.get_unique_values()) == set(_mnist._mnist._mnist_labels.values())
+
+    def test_should_raise_if_file_not_found(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with pytest.raises(FileNotFoundError):
+                load_mnist(tmpdirname, False)
+
+
+class TestFashionMNIST:
+
+    def test_should_download_and_return_mnist(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            train, test = load_fashion_mnist(tmpdirname, True)
+            files = os.listdir(Path(tmpdirname) / _mnist._mnist._fashion_mnist_folder)
+            for mnist_file in _mnist._mnist._fashion_mnist_files.values():
+                assert mnist_file in files
+            assert isinstance(train, ImageDataset)
+            assert isinstance(test, ImageDataset)
+            assert len(train) == 60_000
+            assert len(test) == 10_000
+            train_output = train.get_output()
+            test_output = test.get_output()
+            assert set(train_output.get_unique_values()) == set(test_output.get_unique_values()) == set(_mnist._mnist._fashion_mnist_labels.values())
+
+    def test_should_raise_if_file_not_found(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with pytest.raises(FileNotFoundError):
+                load_fashion_mnist(tmpdirname, False)
+
+
+class TestKMNIST:
+
+    def test_should_download_and_return_mnist(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            train, test = load_kmnist(tmpdirname, True)
+            files = os.listdir(Path(tmpdirname) / _mnist._mnist._kuzushiji_mnist_folder)
+            for mnist_file in _mnist._mnist._kuzushiji_mnist_files.values():
+                assert mnist_file in files
+            assert isinstance(train, ImageDataset)
+            assert isinstance(test, ImageDataset)
+            assert len(train) == 60_000
+            assert len(test) == 10_000
+            train_output = train.get_output()
+            test_output = test.get_output()
+            assert set(train_output.get_unique_values()) == set(test_output.get_unique_values()) == set(_mnist._mnist._kuzushiji_mnist_labels.values())
+
+    def test_should_raise_if_file_not_found(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with pytest.raises(FileNotFoundError):
+                load_kmnist(tmpdirname, False)