From e8a3bb75de2eb722ee98ee19919affdffd0cf9cf Mon Sep 17 00:00:00 2001
From: Robert Jan Schlimbach
Date: Mon, 22 Nov 2021 13:51:34 +0100
Subject: [PATCH] MVP2 extract_embeddings.py

---
 datamodules/camelyon16.py                     |  34 +++---
 load-asap.sh                                  |   2 +
 .../extract_embeddings/extract_embeddings.py  | 112 ++++++++++--------
 3 files changed, 80 insertions(+), 68 deletions(-)

diff --git a/datamodules/camelyon16.py b/datamodules/camelyon16.py
index bac0ae9..018a871 100644
--- a/datamodules/camelyon16.py
+++ b/datamodules/camelyon16.py
@@ -138,8 +138,8 @@ def __init__(
         train_frac: float,
         **throwaway_kwargs
     ):
-        modality_folders = ('images', 'masks')
-        modality_postfixes = ('', '_mask')
+        modality_folders = ('images', 'masks', 'tissue_masks')
+        modality_postfixes = ('', '_mask', '_tissue')

         find_pairs = functools.partial(
             _find_image_mask_pairs_paths,
@@ -148,7 +148,7 @@ def __init__(
             modality_postfixes=modality_postfixes
         )

-        self.image_paths, self.mask_paths = (
+        self.image_paths, self.mask_paths, self.tissue_mask_paths = (
             find_pairs(pattern='test')
             if train == 'test'
             else _train_val_split_paths(
@@ -189,7 +189,7 @@ def _sizes(self) -> np.ndarray:  # np.ndarray[..., 2]

     @functools.lru_cache(maxsize=1)
     def _get_paths(self, index):
-        return self.image_paths[index], self.mask_paths[index]
+        return self.image_paths[index], self.mask_paths[index], self.tissue_mask_paths[index]

     def __getitem__(self, index) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
         """
@@ -209,8 +209,8 @@ def __getitem__(self, index) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]
             patch_index % self._sizes[img_index, 1]  # noqa[E222]
         ))

-        paths = self._get_paths(img_index)
-        patch, labels = (
+        paths = self._get_paths(img_index)[:2]  # keep (image, mask); the tissue mask isn't read here
+        patch, label = (
             ImageReader(path, self.spacing_tolerance).read(
                 self.spacing,
                 *(patch_indices * self.patch_size),  # row, col
@@ -221,9 +221,12 @@ def __getitem__(self, index) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]
         )

         return (  # type: ignore
-            *((patch, labels)
+            *((patch, label)
               if self.transforms is None
-              else self.transforms(image=patch, mask=labels)).values(),
+              else (
+                  (transformed := self.transforms(image=patch, mask=label))['image'],
+                  transformed['mask']
+              )),
             (img_index, patch_indices, *paths)
         )

@@ -241,23 +244,20 @@ def _train_val_split_paths(

    This function is obviously extremely overengineered.
""" assert mode in ('train', 'validation') - temp = [] - for images, masks in modality_arrays: - + for modality_array in modality_arrays: length, split_index = ( - (ln := len(images)), + (ln := len(modality_array[0])), tf if 0 < (tf := round(ln * split_frac)) < ln else 1 if tf == 0 else tf - 1 ) - temp.append( - [ - elem[(slice(split_index) if mode == 'train' else slice(split_index, length))] - for elem in (images, masks) - ] - ) + temp.append([ + elem[(slice(split_index) if mode == 'train' else slice(split_index, length))] + for elem in modality_array + ]) return tuple( # type: ignore map( diff --git a/load-asap.sh b/load-asap.sh index c325a9c..3dc4973 100755 --- a/load-asap.sh +++ b/load-asap.sh @@ -5,6 +5,8 @@ module load Miniconda3/4.7.12.1 module load ASAP/8c9a8fb-fosscuda-2020a +source deactivate +source deactivate source activate 2D-VQ-AE-2 export PYTHONPATH=$PYTHONPATH:~/.conda/envs/2D-VQ-AE-2/lib/python3.8/site-packages/ diff --git a/scripts/extract_embeddings/extract_embeddings.py b/scripts/extract_embeddings/extract_embeddings.py index 79c5e09..c5140e2 100644 --- a/scripts/extract_embeddings/extract_embeddings.py +++ b/scripts/extract_embeddings/extract_embeddings.py @@ -1,19 +1,19 @@ from __future__ import annotations -import logging -from collections.abc import Iterable, Sequence +import functools +from collections.abc import Iterable from dataclasses import dataclass from functools import partial from glob import glob -from itertools import chain, starmap +from itertools import chain from operator import attrgetter from pathlib import Path - -import numpy as np import hydra +import numpy as np import torch import torch.nn.functional as F +import tqdm from hydra import compose, initialize_config_dir from hydra.core.global_hydra import GlobalHydra from hydra.utils import instantiate @@ -43,80 +43,87 @@ def get_encodings( model: nn.Module, dataset: Dataset ): - def setdefault_(array_dict: dict, array_count: dict, names, img_idx): - u_info = np.unique(names, return_index=True, return_counts=True) - for unique_name, unique_index, unique_count in zip(*u_info): - if unique_name not in array_dict: - img_index = img_idx[unique_index] - array_dict[unique_name] = np.empty( - patch_size * (dataset._sizes[img_index]), - dtype=encodings.dtype - ) - array_count[unique_name] = dataset._lengths[img_index] - return u_info - - def pop_array_if_done_(array_dict, array_count, unique_names, unique_counts): - return_dict = {} - for u_name, count in zip(unique_names, unique_counts): - array_count[u_name] -= count - if array_count[u_name] == 0: - array_count.pop(u_name) - return_dict[u_name] = array_dict.pop(u_name) - elif array_count[u_name] < 0: - raise RuntimeError("steps left is less than 0, this shouldn't happen") - return return_dict + @functools.lru_cache(maxsize=1) + def get_slices(patch_idx): + return np.asarray([ + np.mgrid[tuple(slice(*map(int, d)) for d in dim)] + for dim in (torch.cat([patch_idx[None], patch_idx[None]+1]).T.swapaxes(0, 1) * patch_size) + ]) arrays, counts, patch_size = {}, {}, None - for encodings, idx, names in run_eval(model, dataset): - img_idx, patch_idx = idx[0], idx[1:].T - if patch_size is None: - patch_size = np.asarray(encodings.shape[1:]) - u_names, _, u_counts = setdefault_(arrays, counts, names, img_idx) + for ret_values in run_eval(model, dataset): + for (encodings, names, img_idx, patch_idx) in ret_values: + + if patch_size is None: + patch_size = torch.as_tensor(encodings.shape[1:]) + + slices = get_slices(patch_idx) + + u_names, u_idx, u_inverse, u_counts 
+                names, return_counts=True, return_index=True, return_inverse=True
+            )

-        slices = zip(*(
-            tuple(np.s_[start:stop] for start, stop in dim)
-            for dim in (np.asarray((patch_idx, patch_idx+1)).T * patch_size)
-        ))
+            for u, (name, image_index, count) in enumerate(zip(u_names, img_idx[u_idx], u_counts)):
+                current_count = counts.setdefault(name, np.asarray(dataset._lengths[image_index]))
+                current_array = arrays.setdefault(name, torch.empty(
+                    size=tuple(patch_size*dataset._sizes[image_index]),
+                    dtype=encodings.dtype,
+                    device=encodings.device
+                ))

-        for name, slice_idx, value in zip(names, slices, encodings):
-            arrays[name][slice_idx] = value
+                mask = u_inverse == u  # batch items belonging to this unique name
+                current_array[tuple(slices[mask].swapaxes(0, 1))] = encodings[mask]
+                current_count -= count  # in-place; persists in `counts` because it's a np.ndarray

-        yield from pop_array_if_done_(arrays, counts, u_names, u_counts).items()
+                if current_count == 0:
+                    counts.pop(name)
+                    yield name, (arr := arrays.pop(name).cpu().numpy()).astype(np.min_scalar_type(arr.max()))


 @torch.no_grad()
-@torch.autocast('cuda')
-def run_eval(model, dataset, batch_size=75):
-    device = torch.device('cuda')
+def run_eval(model, dataset, batch_size=1800):
     dataloader = DataLoader(
         dataset,
         batch_size=batch_size,
         pin_memory=True,
-        num_workers=6
+        num_workers=6,
+        prefetch_factor=10
     )
+    torch.backends.cudnn.benchmark = True
+    device = torch.device('cuda')
+    model = model.to(device)
+    model.eval()

     def extract_path(path: str) -> str:
         return Path(path).parent.stem + '/' + Path(path).stem

     max_pool = None
+
     for imgs, labels, (img_index, patch_index, img_path, label_path) in dataloader:
-        imgs, labels = imgs.to(device), labels.to(device)
-        encodings, encoding_indices, encoding_loss = tuple(zip(*model.encoder(imgs)))[0]
+        imgs, labels = (
+            imgs.to(device, non_blocking=True, dtype=torch.half),
+            labels.to(device, non_blocking=True, dtype=torch.int16)
+        )
+
+        with torch.autocast('cuda'):
+            encodings, encoding_indices, encoding_loss = tuple(zip(*model.encoder(imgs)))[0]

         if max_pool is None:
             max_pool = partial(F.adaptive_max_pool2d, output_size=encoding_indices.shape[-2:])

-        labels_pooled = max_pool(labels.to(torch.float)).type(labels.type())
+        labels_pooled = max_pool(labels.to(torch.half)).to(labels.dtype).squeeze()

         yield (
-            torch.concat([encoding_indices, labels_pooled.squeeze()]).cpu().numpy(),
-            torch.concat([img_index[:, None], patch_index], dim=1).repeat(2, 1).T.cpu().numpy(),
-            np.asarray(list(map(extract_path, chain(img_path, label_path))))
+            (data, list(map(extract_path, paths)), img_index, patch_index)
+            for data, paths in (
+                (encoding_indices.to(torch.int16), img_path),  # assumes < 2**15 codebook indices, so int16 is safe
+                (labels_pooled, label_path)
+            )
         )

@@ -136,6 +143,7 @@ def main(

     # TODO: smarter checkpoint finder instead of just taking the last checkpoint
     model = VQAE.load_from_checkpoint(sorted(checkpoint_path)[-1])  # type: ignore
+    del model.decoder  # don't need the decoder

     GlobalHydra.instance().clear()
     with initialize_config_dir(str(ckpt_folder / '.hydra')):
@@ -157,11 +165,13 @@ def main(

     for train_stage in ('train', 'validation', 'test'):  # TODO: replace with Enum
         dataset = instantiate({**dataset_config, **{'train': train_stage}})
-        for array_name, array in get_encodings(model, dataset):
+        for array_name, array in tqdm.tqdm(
+            get_encodings(model, dataset),
+            total=len(dataset._lengths) * 2  # FIXME: remove hardcoded 2 (one encodings + one labels array per image)
+        ):
             out_path = ckpt_folder / 'encodings' / (array_name + '.npy')
             out_path.parent.mkdir(parents=True, exist_ok=True)
             np.save(str(out_path), array)
-            logging.info(f"Saved array to {out_path}")


 def find_ckpt_folder(path: Path, pattern: Iterable[str]):