Skip to content
This repository has been archived by the owner on Nov 23, 2024. It is now read-only.

Commit

Permalink
test: added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Marsmaennchen221 committed May 8, 2024
1 parent d87f65a commit c674e14
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/safeds_datasets/image/_mnist/_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import struct
import sys
import tempfile
import urllib.request
from array import array
from pathlib import Path
Expand Down Expand Up @@ -160,16 +159,16 @@ def _load_mnist_like(path: str | Path, files: dict[str, str], labels: dict[int,
with gzip.open(path / file_path, mode='rb') as label_file:
magic, size = struct.unpack(">II", label_file.read(8))
if magic != 2049:
raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2049.")
raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2049.") # pragma: no cover
if "train" in file_name:
train_labels = Column(file_name, [labels[label_index] for label_index in array("B", label_file.read())])
else:
test_labels = Column(file_name, array("B", label_file.read()))
test_labels = Column(file_name, [labels[label_index] for label_index in array("B", label_file.read())])
else:
with gzip.open(path / file_path, mode='rb') as image_file:
magic, size, rows, cols = struct.unpack(">IIII", image_file.read(16))
if magic != 2051:
raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2051.")
raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2051.") # pragma: no cover
image_data = array("B", image_file.read())
image_tensor = torch.empty(size, 1, rows, cols)
for i in range(size):
Expand All @@ -183,7 +182,7 @@ def _load_mnist_like(path: str | Path, files: dict[str, str], labels: dict[int,
else:
test_image_list = image_list
if train_image_list is None or test_image_list is None or train_labels is None or test_labels is None:
raise ValueError
raise ValueError # pragma: no cover
return ImageDataset[Column](train_image_list, train_labels, 32, shuffle=True), ImageDataset[Column](test_image_list, test_labels, 32)


Expand All @@ -197,7 +196,7 @@ def _download_mnist_like(path: str | Path, files: dict[str, str], links: list[st
print() # noqa: T201
break
except HTTPError as e:
print(f"An error occurred while downloading: {e}") # noqa: T201
print(f"An error occurred while downloading: {e}") # noqa: T201 # pragma: no cover


def _report_download_progress(current_packages: int, package_size: int, file_size: int) -> None:
Expand Down
Empty file.
Empty file.
74 changes: 74 additions & 0 deletions tests/safeds_datasets/image/_mnist/test_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os
import tempfile
from pathlib import Path

import pytest
from safeds.data.labeled.containers import ImageDataset

from safeds_datasets.image import load_mnist, _mnist, load_fashion_mnist, load_kmnist


class TestMNIST:
    """Tests for ``load_mnist``."""

    def test_should_download_and_return_mnist(self):
        """Load MNIST into a temporary directory and check files, types, sizes and labels."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            train, test = load_mnist(tmpdirname, True)

            # Every expected archive must be present in the dataset folder.
            mnist_module = _mnist._mnist
            present_files = os.listdir(Path(tmpdirname) / mnist_module._mnist_folder)
            for expected_file in mnist_module._mnist_files.values():
                assert expected_file in present_files

            assert isinstance(train, ImageDataset)
            assert isinstance(test, ImageDataset)
            assert len(train) == 60_000
            assert len(test) == 10_000

            # Both splits must expose exactly the configured label names.
            train_output = train.get_output()
            test_output = test.get_output()
            train_label_set = set(train_output.get_unique_values())
            test_label_set = set(test_output.get_unique_values())
            expected_label_set = set(mnist_module._mnist_labels.values())
            assert train_label_set == test_label_set == expected_label_set

    def test_should_raise_if_file_not_found(self):
        """Expect ``FileNotFoundError`` when loading from an empty directory with the flag set to False."""
        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
            load_mnist(tmpdirname, False)


class TestFashionMNIST:
    """Tests for ``load_fashion_mnist``."""

    def test_should_download_and_return_mnist(self):
        """Load Fashion-MNIST into a temporary directory and check files, types, sizes and labels."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            train, test = load_fashion_mnist(tmpdirname, True)

            # Every expected archive must be present in the dataset folder.
            mnist_module = _mnist._mnist
            present_files = os.listdir(Path(tmpdirname) / mnist_module._fashion_mnist_folder)
            for expected_file in mnist_module._fashion_mnist_files.values():
                assert expected_file in present_files

            assert isinstance(train, ImageDataset)
            assert isinstance(test, ImageDataset)
            assert len(train) == 60_000
            assert len(test) == 10_000

            # Both splits must expose exactly the configured label names.
            train_output = train.get_output()
            test_output = test.get_output()
            train_label_set = set(train_output.get_unique_values())
            test_label_set = set(test_output.get_unique_values())
            expected_label_set = set(mnist_module._fashion_mnist_labels.values())
            assert train_label_set == test_label_set == expected_label_set

    def test_should_raise_if_file_not_found(self):
        """Expect ``FileNotFoundError`` when loading from an empty directory with the flag set to False."""
        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
            load_fashion_mnist(tmpdirname, False)


class TestKMNIST:
    """Tests for ``load_kmnist``."""

    def test_should_download_and_return_mnist(self):
        """Load Kuzushiji-MNIST into a temporary directory and check files, types, sizes and labels."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            train, test = load_kmnist(tmpdirname, True)

            # Every expected archive must be present in the dataset folder.
            mnist_module = _mnist._mnist
            present_files = os.listdir(Path(tmpdirname) / mnist_module._kuzushiji_mnist_folder)
            for expected_file in mnist_module._kuzushiji_mnist_files.values():
                assert expected_file in present_files

            assert isinstance(train, ImageDataset)
            assert isinstance(test, ImageDataset)
            assert len(train) == 60_000
            assert len(test) == 10_000

            # Both splits must expose exactly the configured label names.
            train_output = train.get_output()
            test_output = test.get_output()
            train_label_set = set(train_output.get_unique_values())
            test_label_set = set(test_output.get_unique_values())
            expected_label_set = set(mnist_module._kuzushiji_mnist_labels.values())
            assert train_label_set == test_label_set == expected_label_set

    def test_should_raise_if_file_not_found(self):
        """Expect ``FileNotFoundError`` when loading from an empty directory with the flag set to False."""
        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
            load_kmnist(tmpdirname, False)

0 comments on commit c674e14

Please sign in to comment.