Commit: MVP2 extract_embeddings.py

robogast committed Nov 22, 2021
1 parent a0aee3e commit e8a3bb7

Showing 3 changed files with 80 additions and 69 deletions.
35 changes: 17 additions & 18 deletions datamodules/camelyon16.py
@@ -138,8 +138,8 @@ def __init__(
         train_frac: float,
         **throwaway_kwargs
     ):
-        modality_folders = ('images', 'masks')
-        modality_postfixes = ('', '_mask')
+        modality_folders = ('images', 'masks', 'tissue_masks')
+        modality_postfixes = ('', '_mask', '_tissue')
 
         find_pairs = functools.partial(
             _find_image_mask_pairs_paths,
@@ -148,7 +148,7 @@ def __init__(
             modality_postfixes=modality_postfixes
         )
 
-        self.image_paths, self.mask_paths = (
+        self.image_paths, self.mask_paths, self.tissue_mask_paths = (
            find_pairs(pattern='test')
            if train == 'test'
            else _train_val_split_paths(
@@ -189,7 +189,7 @@ def _sizes(self) -> np.ndarray:  # np.ndarray[..., 2]
 
     @functools.lru_cache(maxsize=1)
     def _get_paths(self, index):
-        return self.image_paths[index], self.mask_paths[index]
+        return self.image_paths[index], self.mask_paths[index], self.tissue_mask_paths[index]
 
     def __getitem__(self, index) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
         """
@@ -209,8 +209,8 @@ def __getitem__(self, index) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]
             patch_index % self._sizes[img_index, 1]  # noqa[E222]
         ))
 
-        paths = self._get_paths(img_index)
-        patch, labels = (
+        paths = self._get_paths(img_index)[:2]  # img, mask, _
+        patch, label = (
             ImageReader(path, self.spacing_tolerance).read(
                 self.spacing,
                 *(patch_indices * self.patch_size),  # row, col
@@ -221,9 +221,12 @@ def __getitem__(self, index) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]
         )
 
         return (  # type: ignore
-            *((patch, labels)
+            *((patch, label)
               if self.transforms is None
-              else self.transforms(image=patch, mask=labels)).values(),
+              else (
+                  (transformed := self.transforms(image=patch, mask=label))['image'],
+                  transformed['mask']
+              )),
            (img_index, patch_indices, *paths)
        )
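The rewritten branch above replaces the old self.transforms(...).values() unpacking with explicit dict lookups through a walrus assignment, so the (image, mask) order is pinned rather than inherited from the dict's key order. It also sidesteps the old bug where .values() was applied to the ternary's result, which would fail on the bare (patch, labels) tuple whenever transforms was None. A minimal sketch of the pattern, assuming a hypothetical albumentations-style callable that takes image=/mask= keywords and returns a dict under those keys:

    import numpy as np

    # Hypothetical stand-in for an albumentations-style transform.
    def transforms(image, mask):
        return {'image': np.flipud(image).copy(), 'mask': np.flipud(mask).copy()}

    patch = np.zeros((64, 64, 3), dtype=np.uint8)
    label = np.zeros((64, 64), dtype=np.uint8)

    # Bind the result dict once, then read both keys explicitly.
    patch_t, label_t = (
        (transformed := transforms(image=patch, mask=label))['image'],
        transformed['mask']
    )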

@@ -241,23 +244,19 @@ def _train_val_split_paths(
     This function is obviously extremely overengineered.
     """
     assert mode in ('train', 'validation')
 
     temp = []
-    for images, masks in modality_arrays:
+
+    for modality_array in modality_arrays:
         length, split_index = (
-            (ln := len(images)),
+            (ln := len(modality_array[0])),
             tf if 0 < (tf := round(ln * split_frac)) < ln
             else 1 if tf == 0
             else tf - 1
         )
 
-        temp.append(
-            [
-                elem[(slice(split_index) if mode == 'train' else slice(split_index, length))]
-                for elem in (images, masks)
-            ]
-        )
+        temp.append([
+            elem[(slice(split_index) if mode == 'train' else slice(split_index, length))]
+            for elem in modality_array
+        ])
 
     return tuple(  # type: ignore
         map(
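The split_index expression above clamps the train/validation split so that neither side ends up empty: a fraction that rounds to 0 is bumped to 1, and one that rounds to the full length gives up one element. The same logic, unrolled into a plain function purely for illustration:

    def split_index(length: int, split_frac: float) -> int:
        # Unrolled version of the walrus expression in _train_val_split_paths.
        tf = round(length * split_frac)
        if 0 < tf < length:
            return tf                        # common case: plain rounding
        return 1 if tf == 0 else tf - 1      # keep both partitions non-empty

    assert split_index(10, 0.8) == 8    # plain rounding
    assert split_index(10, 0.01) == 1   # rounds to 0, clamped up
    assert split_index(10, 0.99) == 9   # rounds to 10, clamped down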
2 changes: 2 additions & 0 deletions load-asap.sh
@@ -5,6 +5,8 @@ module load Miniconda3/4.7.12.1
 
 module load ASAP/8c9a8fb-fosscuda-2020a
 
+source deactivate
+source deactivate
 source activate 2D-VQ-AE-2
 
 export PYTHONPATH=$PYTHONPATH:~/.conda/envs/2D-VQ-AE-2/lib/python3.8/site-packages/
112 changes: 61 additions & 51 deletions scripts/extract_embeddings/extract_embeddings.py
@@ -1,19 +1,19 @@
 from __future__ import annotations
 
-import logging
-from collections.abc import Iterable, Sequence
+import functools
+from collections.abc import Iterable
 from dataclasses import dataclass
 from functools import partial
 from glob import glob
-from itertools import chain, starmap
+from itertools import chain
 from operator import attrgetter
 from pathlib import Path
 
 
-import numpy as np
 import hydra
+import numpy as np
 import torch
 import torch.nn.functional as F
+import tqdm
 from hydra import compose, initialize_config_dir
 from hydra.core.global_hydra import GlobalHydra
 from hydra.utils import instantiate
@@ -43,80 +43,87 @@ def get_encodings(
     model: nn.Module,
     dataset: Dataset
 ):
-    def setdefault_(array_dict: dict, array_count: dict, names, img_idx):
-        u_info = np.unique(names, return_index=True, return_counts=True)
-        for unique_name, unique_index, unique_count in zip(*u_info):
-            if unique_name not in array_dict:
-                img_index = img_idx[unique_index]
-                array_dict[unique_name] = np.empty(
-                    patch_size * (dataset._sizes[img_index]),
-                    dtype=encodings.dtype
-                )
-                array_count[unique_name] = dataset._lengths[img_index]
-        return u_info
-
-    def pop_array_if_done_(array_dict, array_count, unique_names, unique_counts):
-        return_dict = {}
-        for u_name, count in zip(unique_names, unique_counts):
-            array_count[u_name] -= count
-            if array_count[u_name] == 0:
-                array_count.pop(u_name)
-                return_dict[u_name] = array_dict.pop(u_name)
-            elif array_count[u_name] < 0:
-                raise RuntimeError("steps left is less than 0, this shouldn't happen")
-        return return_dict
+    @functools.lru_cache(maxsize=1)
+    def get_slices(patch_idx):
+        return np.asarray([
+            np.mgrid[tuple(slice(*map(int, d)) for d in dim)]
+            for dim in (torch.cat([patch_idx[None], patch_idx[None]+1]).T.swapaxes(0, 1) * patch_size)
+        ])
 
     arrays, counts, patch_size = {}, {}, None
-    for encodings, idx, names in run_eval(model, dataset):
-        img_idx, patch_idx = idx[0], idx[1:].T
-        if patch_size is None:
-            patch_size = np.asarray(encodings.shape[1:])
-
-        u_names, _, u_counts = setdefault_(arrays, counts, names, img_idx)
+    for ret_values in run_eval(model, dataset):
+        for (encodings, names, img_idx, patch_idx) in ret_values:
+
+            if patch_size is None:
+                patch_size = torch.as_tensor(encodings.shape[1:])
+
+            slices = get_slices(patch_idx)
+
+            u_names, u_idx, u_inverse, u_counts = np.unique(
+                names, return_counts=True, return_index=True, return_inverse=True
+            )
 
-        slices = zip(*(
-            tuple(np.s_[start:stop] for start, stop in dim)
-            for dim in (np.asarray((patch_idx, patch_idx+1)).T * patch_size)
-        ))
+            for name, image_index, count in zip(u_names, img_idx[u_idx], u_counts):
+                current_count = counts.setdefault(name, np.asarray(dataset._lengths[image_index]))
+                current_array = arrays.setdefault(name, torch.empty(
+                    size=tuple(patch_size*dataset._sizes[image_index]),
+                    dtype=encodings.dtype,
+                    device=encodings.device
+                ))
 
-        for name, slice_idx, value in zip(names, slices, encodings):
-            arrays[name][slice_idx] = value
+                mask = u_inverse == int(image_index)
+                current_array[slices[mask].swapaxes(0, 1)] = encodings[mask]
+                current_count -= count  # persistent because of np.array
 
-        yield from pop_array_if_done_(arrays, counts, u_names, u_counts).items()
+                if current_count == 0:
+                    counts.pop(name)
+                    yield name, (arr := arrays.pop(name).cpu().numpy()).astype(np.min_scalar_type(arr.max()))
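The new get_encodings drops the setdefault_/pop_array_if_done_ helpers in favour of np.unique(..., return_inverse=True) bookkeeping and a vectorised scatter: get_slices expands each patch's grid position into dense row/column index grids, so every patch in a batch is written into its per-image array with one fancy-indexing assignment. A simplified numpy sketch of that scatter under assumed shapes (two 4x4 patches into an 8x16 target), using explicit row/column index arrays:

    import numpy as np

    patch_size = np.asarray([4, 4])            # (h, w) of one encoding patch
    patch_idx = np.asarray([[0, 0], [1, 2]])   # (row, col) grid position per patch

    # (batch, coord, start/stop) pixel bounds, then one index grid per patch.
    bounds = np.stack([patch_idx, patch_idx + 1], axis=-1) * patch_size[:, None]
    slices = np.asarray([
        np.mgrid[tuple(slice(int(start), int(stop)) for start, stop in dim)]
        for dim in bounds
    ])  # shape: (batch, 2, h, w)

    full = np.zeros((8, 16), dtype=np.int16)
    encodings = np.arange(2 * 4 * 4, dtype=np.int16).reshape(2, 4, 4)
    full[slices[:, 0], slices[:, 1]] = encodings  # scatter both patches at once

Once an image's remaining patch count reaches zero, the finished array is popped and downcast via np.min_scalar_type(arr.max()), so a codebook of, say, 512 entries is stored as uint16 rather than a wider integer type.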


 @torch.no_grad()
-@torch.autocast('cuda')
-def run_eval(model, dataset, batch_size=75):
-    device = torch.device('cuda')
+def run_eval(model, dataset, batch_size=1800):
 
     dataloader = DataLoader(
         dataset,
         batch_size=batch_size,
         pin_memory=True,
-        num_workers=6
+        num_workers=6,
+        prefetch_factor=10
     )
 
+    torch.backends.cudnn.benchmark = True
+    device = torch.device('cuda')
+
     model = model.to(device)
     model.eval()
 
     def extract_path(path: str) -> str:
         return Path(path).parent.stem + '/' + Path(path).stem
 
     max_pool = None
 
     for imgs, labels, (img_index, patch_index, img_path, label_path) in dataloader:
-        imgs, labels = imgs.to(device), labels.to(device)
-
-        encodings, encoding_indices, encoding_loss = tuple(zip(*model.encoder(imgs)))[0]
+        imgs, labels = (
+            imgs.to(device, non_blocking=True, dtype=torch.half),
+            labels.to(device, non_blocking=True, dtype=torch.int16)
+        )
+
+        with torch.autocast('cuda'):
+            encodings, encoding_indices, encoding_loss = tuple(zip(*model.encoder(imgs)))[0]
 
         if max_pool is None:
             max_pool = partial(F.adaptive_max_pool2d, output_size=encoding_indices.shape[-2:])
 
-        labels_pooled = max_pool(labels.to(torch.float)).type(labels.type())
+        labels_pooled = max_pool(labels.to(torch.half)).to(labels.dtype).squeeze()
 
         yield (
-            torch.concat([encoding_indices, labels_pooled.squeeze()]).cpu().numpy(),
-            torch.concat([img_index[:, None], patch_index], dim=1).repeat(2, 1).T.cpu().numpy(),
-            np.asarray(list(map(extract_path, chain(img_path, label_path))))
+            (data, list(map(extract_path, paths)), img_index, patch_index)
+            for data, paths in (
+                (encoding_indices.to(torch.int16), img_path),  # make a ndim < 2**15 assumption
+                (labels_pooled, label_path)
+            )
        )
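run_eval now yields, per batch, a generator of two tuples, one for the encoding indices and one for the max-pooled labels, which is why get_encodings consumes it with a nested loop. The label pooling shrinks each mask to the encoder's output resolution, and max-pooling is the natural reduction here: a pooled cell stays positive if any pixel under it was positive. A sketch with assumed shapes (the commit pools in torch.half on CUDA; plain float is used here so it runs on CPU):

    import torch
    import torch.nn.functional as F

    # Assumed shapes: 8 integer masks at input resolution 256x256, pooled to a
    # hypothetical encoder output resolution of 64x64.
    labels = torch.randint(0, 2, (8, 1, 256, 256), dtype=torch.int16)

    pooled = F.adaptive_max_pool2d(labels.to(torch.float), output_size=(64, 64))
    pooled = pooled.to(labels.dtype).squeeze(1)  # back to int16, drop channel dim

    assert pooled.shape == (8, 64, 64)
    # Adaptive cells tile the whole input, so the global maximum is preserved.
    assert pooled.max() == labels.max()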


Expand All @@ -136,6 +143,7 @@ def main(

# TODO: smarter checkpoint finder instead of just taking the last checkpoint
model = VQAE.load_from_checkpoint(sorted(checkpoint_path)[-1]) # type: ignore
del model.decoder # don't need the decoder

GlobalHydra.instance().clear()
with initialize_config_dir(str(ckpt_folder / '.hydra')):
@@ -157,11 +165,13 @@ def main(
 
     for train_stage in ('train', 'validation', 'test'):  # TODO: replace with Enum
         dataset = instantiate({**dataset_config, **{'train': train_stage}})
-        for array_name, array in get_encodings(model, dataset):
+        for array_name, array in tqdm.tqdm(
+            get_encodings(model, dataset),
+            total=len(dataset._lengths) * 2  # FIXME: remove hardcoded 2
+        ):
             out_path = ckpt_folder / 'encodings' / (array_name + '.npy')
             out_path.parent.mkdir(parents=True, exist_ok=True)
             np.save(str(out_path), array)
-            logging.info(f"Saved array to {out_path}")
 
 
 def find_ckpt_folder(path: Path, pattern: Iterable[str]):
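After main finishes, each slide's embedding grid lands under encodings/ next to the checkpoint, with extract_path naming it <parent-folder>/<file-stem>.npy. A short sketch of reading the results back, where the run-folder path is hypothetical:

    import numpy as np
    from pathlib import Path

    ckpt_folder = Path('outputs/some_run')  # hypothetical checkpoint folder
    for npy_path in sorted((ckpt_folder / 'encodings').glob('*/*.npy')):
        arr = np.load(npy_path)
        print(npy_path.stem, arr.shape, arr.dtype)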
