Skip to content
This repository has been archived by the owner on Nov 23, 2024. It is now read-only.

Commit

Permalink
Merge branch 'add-mnist-datasets' of https://github.com/Safe-DS/Datasets
Browse files Browse the repository at this point in the history
 into add-mnist-datasets

# Conflicts:
#	src/safeds_datasets/image/_mnist/_mnist.py
  • Loading branch information
Marsmaennchen221 committed May 8, 2024
2 parents c674e14 + fda72fc commit d066622
Showing 1 changed file with 59 additions and 11 deletions.
70 changes: 59 additions & 11 deletions src/safeds_datasets/image/_mnist/_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,34 @@

_fashion_mnist_links: list[str] = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
_fashion_mnist_files: dict[str, str] = _mnist_files
_fashion_mnist_labels: dict[int, str] = {0: "T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot"}
_fashion_mnist_labels: dict[int, str] = {
0: "T-shirt/top",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle boot",
}
_fashion_mnist_folder: str = "fashion-mnist"

_kuzushiji_mnist_links: list[str] = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
_kuzushiji_mnist_files: dict[str, str] = _mnist_files
_kuzushiji_mnist_labels: dict[int, str] = {0: "\u304a", 1: "\u304d", 2: "\u3059", 3: "\u3064", 4: "\u306a", 5: "\u306f", 6: "\u307e", 7: "\u3084", 8: "\u308c", 9: "\u3092"}
_kuzushiji_mnist_labels: dict[int, str] = {
0: "\u304a",
1: "\u304d",
2: "\u3059",
3: "\u3064",
4: "\u306a",
5: "\u306f",
6: "\u307e",
7: "\u3084",
8: "\u308c",
9: "\u3092",
}
_kuzushiji_mnist_folder: str = "kmnist"


Expand Down Expand Up @@ -68,7 +90,11 @@ def load_mnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Co
missing_files.append(file_path)
if len(missing_files) > 0:
if download:
_download_mnist_like(path, {name: f_path for name, f_path in _mnist_files.items() if f_path in missing_files}, _mnist_links)
_download_mnist_like(
path,
{name: f_path for name, f_path in _mnist_files.items() if f_path in missing_files},
_mnist_links,
)
else:
raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files]}")
return _load_mnist_like(path, _mnist_files, _mnist_labels)
Expand Down Expand Up @@ -104,7 +130,11 @@ def load_fashion_mnist(path: str | Path, download: bool = True) -> tuple[ImageDa
missing_files.append(file_path)
if len(missing_files) > 0:
if download:
_download_mnist_like(path, {name: f_path for name, f_path in _fashion_mnist_files.items() if f_path in missing_files}, _fashion_mnist_links)
_download_mnist_like(
path,
{name: f_path for name, f_path in _fashion_mnist_files.items() if f_path in missing_files},
_fashion_mnist_links,
)
else:
raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files]}")
return _load_mnist_like(path, _fashion_mnist_files, _fashion_mnist_labels)
Expand Down Expand Up @@ -140,13 +170,21 @@ def load_kmnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[C
missing_files.append(file_path)
if len(missing_files) > 0:
if download:
_download_mnist_like(path, {name: f_path for name, f_path in _kuzushiji_mnist_files.items() if f_path in missing_files}, _kuzushiji_mnist_links)
_download_mnist_like(
path,
{name: f_path for name, f_path in _kuzushiji_mnist_files.items() if f_path in missing_files},
_kuzushiji_mnist_links,
)
else:
raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files]}")
return _load_mnist_like(path, _kuzushiji_mnist_files, _kuzushiji_mnist_labels)


def _load_mnist_like(path: str | Path, files: dict[str, str], labels: dict[int, str]) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
def _load_mnist_like(
path: str | Path,
files: dict[str, str],
labels: dict[int, str],
) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
_init_default_device()

path = Path(path)
Expand All @@ -156,23 +194,29 @@ def _load_mnist_like(path: str | Path, files: dict[str, str], labels: dict[int,
train_image_list: ImageList | None = None
for file_name, file_path in files.items():
if "idx1" in file_name:
with gzip.open(path / file_path, mode='rb') as label_file:
with gzip.open(path / file_path, mode="rb") as label_file:
magic, size = struct.unpack(">II", label_file.read(8))
if magic != 2049:
raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2049.") # pragma: no cover
if "train" in file_name:
train_labels = Column(file_name, [labels[label_index] for label_index in array("B", label_file.read())])
train_labels = Column(
file_name,
[labels[label_index] for label_index in array("B", label_file.read())],
)
else:
test_labels = Column(file_name, [labels[label_index] for label_index in array("B", label_file.read())])
else:
with gzip.open(path / file_path, mode='rb') as image_file:
with gzip.open(path / file_path, mode="rb") as image_file:
magic, size, rows, cols = struct.unpack(">IIII", image_file.read(16))
if magic != 2051:
raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2051.") # pragma: no cover
image_data = array("B", image_file.read())
image_tensor = torch.empty(size, 1, rows, cols)
for i in range(size):
image_tensor[i, 0] = torch.frombuffer(image_data[i * rows * cols:(i + 1) * rows * cols], dtype=torch.uint8).reshape(rows, cols)
image_tensor[i, 0] = torch.frombuffer(
image_data[i * rows * cols : (i + 1) * rows * cols],
dtype=torch.uint8,
).reshape(rows, cols)
image_list = _SingleSizeImageList()
image_list._tensor = image_tensor
image_list._tensor_positions_to_indices = list(range(size))
Expand All @@ -183,7 +227,11 @@ def _load_mnist_like(path: str | Path, files: dict[str, str], labels: dict[int,
test_image_list = image_list
if train_image_list is None or test_image_list is None or train_labels is None or test_labels is None:
raise ValueError # pragma: no cover
return ImageDataset[Column](train_image_list, train_labels, 32, shuffle=True), ImageDataset[Column](test_image_list, test_labels, 32)
return ImageDataset[Column](train_image_list, train_labels, 32, shuffle=True), ImageDataset[Column](
test_image_list,
test_labels,
32,
)


def _download_mnist_like(path: str | Path, files: dict[str, str], links: list[str]) -> None:
Expand Down

0 comments on commit d066622

Please sign in to comment.