diff --git a/CHANGELOG b/CHANGELOG index 2b954cc..467deca 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,5 @@ +0.4.0 + - feat: move from 2d to 3d array operations (#12) 0.3.2 - maintenance release 0.3.1 diff --git a/docs/README.md b/docs/README.md index 694748f..e8441e1 100644 --- a/docs/README.md +++ b/docs/README.md @@ -6,6 +6,7 @@ To install the requirements for building the documentation, run To compile the documentation, run + cd docs sphinx-build . _build diff --git a/docs/requirements.txt b/docs/requirements.txt index 15408de..87e4bc3 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,3 @@ -sphinx==4.3.0 -sphinxcontrib.bibtex>=2.0 -sphinx_rtd_theme==1.0 - +sphinx +sphinxcontrib.bibtex +sphinx_rtd_theme diff --git a/docs/sec_array_layout.rst b/docs/sec_array_layout.rst new file mode 100644 index 0000000..d5b26cd --- /dev/null +++ b/docs/sec_array_layout.rst @@ -0,0 +1,62 @@ +Data Array Layouts +================== + +.. _sec_doc_array_layout: + +Since version 0.4.0, `qpretrieve` accepts 3D (z,y,x) arrays as input. +Additionally, it **always** returns data as the 3D array layout. + +We use the term "*array layout*" to define the different ways to represent data. +The currently accepted array layouts are: 2D, RGB, RGBA, 3D. + +Summary of allowed Array Layouts:: + + Input -> Output + 2D (y,x) -> 3D (1,y,x) + RGB (y,x,3) -> 3D (1,y,x) + RGBA (y,x,4) -> 3D (1,y,x) + 3D (z,y,x) -> 3D (z,y,x) + + +Notes on RGB/RGBA Array Layouts +------------------------------- + +**Inputting RGB(A)**: See the Notes section of +:func:`.convert_data_to_3d_array_layout` for extra information on RGB(A) +array layouts. + +**Outputting RGB(A)**: See the Notes section of +:meth:`.OffAxisHologram.get_data_with_input_layout` or +:meth:`.QLSInterferogram.get_data_with_input_layout` for information on +outputting of RGB(A) array layouts. + + +Converting to and from Array Layouts +------------------------------------ + +`qpretrieve` will automatically handle the above allowed array layouts. +In other words, if you provide any 2D, RGB, RGBA, or 3D data as input to +:class:`.OffAxisHologram` or :class:`.QLSInterferogram` +the class will handle everything for you. + +However, if you want to have your processed data in the same array layout as when +you inputted it, then you can use the convenience method +:meth:`get_data_with_input_layout` to do just that. For example, if +your input data was a 2D array, you can get the processed field, phase, +amplitude etc like so: + +.. code-block:: python + + # 2D data inputted + oah = qpretrieve.OffAxisHologram(hologram_2d) + # do some processing + ... + # get your data as a 2D array layout + field_2d = oah.get_data_with_input_layout("field") + phase_2d = oah.get_data_with_input_layout("phase") + amplitude_2d = oah.get_data_with_input_layout("amplitude") + + # you can also use the class attributes + field_2d = oah.get_data_with_input_layout(oah.field) + phase_2d = oah.get_data_with_input_layout(oah.phase) + amplitude_2d = oah.get_data_with_input_layout(oah.amplitude) diff --git a/docs/sec_basic_use.rst b/docs/sec_basic_use.rst new file mode 100644 index 0000000..c4ba24e --- /dev/null +++ b/docs/sec_basic_use.rst @@ -0,0 +1,78 @@ +Basic Use of qpretrieve +======================= + +To run this code locally, you can download the example dataset from the +`qpretrieve repository `_ + + +.. code-block:: python + + import qpretrieve + + # load your experimental data (data from the qpretrieve repository) + edata = np.load("qpretrieve/examples/data/hologram_cell.npz") + data = edata["data"] + data_bg = edata["bg_data"] + + # create an off-axis hologram object, to process the holography data + oah = qpretrieve.OffAxisHologram(data) + oah.run_pipeline(filter_name=filter_name, filter_size=filter_size) + # process background hologram + oah_bg = qpretrieve.OffAxisHologram(data=data_2d_bg) + oah_bg.process_like(oah) + + print(f"Original Hologram shape: {data.shape}") + print(f"Processed Field shape: {oah.field.shape}") + + # Now you can look at the phase data + phase_corrected = oah.phase - oah_bg.phase + + + +How qpretrieve now handles data +------------------------------- + +In older versions of qpretrieve, only a single image could be processed at a time. +Since version (0.4.0), it accepts 3D array layouts, and always returns +the data as 3D, regardless of the input array layout. + +Click here for more :ref:`information on array layouts `. +There are speed benchmarks given for different inputs in the +:ref:`examples ` section. + + +New version example code and output: +.................................... + +.. code-block:: python + + import numpy as np + import qpretrieve + + hologram = np.ones(shape=(256, 256)) + oah = qpretrieve.OffAxisHologram(hologram) + oah.run_pipeline() + assert oah.field.shape == (1, 256, 256) # <- now a 3D array is returned + # if you want the original array layout (2d) + field_2d = oah.get_data_with_input_layout("field") + assert field_2d.shape == (256, 256) + + # this means you can input 3D arrays + hologram_3d = np.ones(shape=(50, 256, 256)) + oah = qpretrieve.OffAxisHologram(hologram_3d) + oah.run_pipeline() + assert oah.field.shape == (50, 256, 256) # <- always a 3D array + + +Old version example code and output: +.................................... + +.. code-block:: python + + import numpy as np + import qpretrieve # versions older than 0.4.0 + + hologram = np.ones(shape=(256, 256)) + oah = qpretrieve.OffAxisHologram(hologram) + oah.run_pipeline() + assert oah.field.shape == hologram.shape # <- old version returned a 2D array diff --git a/docs/sec_code_reference.rst b/docs/sec_code_reference.rst index dcf88ec..340fa19 100644 --- a/docs/sec_code_reference.rst +++ b/docs/sec_code_reference.rst @@ -60,3 +60,12 @@ Quadriwave lateral shearing interferometry (QLSI) .. automodule:: qpretrieve.interfere.if_qlsi :members: :inherited-members: + +Data Array Layout +================= + +.. _sec_code_array_layout: + +.. automodule:: qpretrieve.data_array_layout + :members: + :inherited-members: diff --git a/docs/sec_examples.rst b/docs/sec_examples.rst index b1e99a1..63cdf5b 100644 --- a/docs/sec_examples.rst +++ b/docs/sec_examples.rst @@ -12,3 +12,5 @@ Examples .. fancy_include:: filter_visualization.py .. fancy_include:: fourier_scale.py + +.. fancy_include:: fft_batch_speeds.py diff --git a/docs/sec_getting_started.rst b/docs/sec_getting_started.rst index 96721a4..2892aa3 100644 --- a/docs/sec_getting_started.rst +++ b/docs/sec_getting_started.rst @@ -6,4 +6,6 @@ Getting started :maxdepth: 2 sec_installation + sec_basic_use + sec_array_layout sec_userapi diff --git a/docs/sec_userapi.rst b/docs/sec_userapi.rst index af68d78..816421c 100644 --- a/docs/sec_userapi.rst +++ b/docs/sec_userapi.rst @@ -21,4 +21,4 @@ Then your analysis could be as simple as With ``dhm``, an instance of :class:`.OffAxisHologram`, you now have full access to all intermediate computation results. You can pass additional keyword arguments during instantiation or pass them to -:func:`.OffAxisHologram.run_pipeline`. +:meth:`.OffAxisHologram.run_pipeline`. diff --git a/examples/fft_batch_speeds.png b/examples/fft_batch_speeds.png new file mode 100644 index 0000000..0675ac9 Binary files /dev/null and b/examples/fft_batch_speeds.png differ diff --git a/examples/fft_batch_speeds.py b/examples/fft_batch_speeds.py new file mode 100644 index 0000000..556b7e0 --- /dev/null +++ b/examples/fft_batch_speeds.py @@ -0,0 +1,91 @@ +"""Fourier Transform speed benchmarks for OAH + +This example visualizes the speed for different batch sizes for +the available FFT Filters. The y-axis shows the average speed of a pipeline +run for the Off-Axis Hologram class :class:`.OffAxisHologram`, including +background data processing. Therefore, four FFTs are run per pipeline. + +- Optimum batch size is between 64 and 256 for 256x256pix imgs (incl padding), + but will be limited by your computer's RAM. +- Here, batch size is the size of the 3D stack in z. +- Note that each pipeline runs 4 FFTs. For example, batch 8 runs 8*4=32 FFTs. + +""" +import time +import matplotlib.pylab as plt +import numpy as np +import qpretrieve +from qpretrieve.data_array_layout import convert_data_to_3d_array_layout +from qpretrieve.fourier import FFTFilterNumpy, FFTFilterPyFFTW + +# load the experimental data +edata = np.load("./data/hologram_cell.npz") + +n_transforms_list = [8, 16, 32, 64, 128, 256] +subtract_mean = True +padding = True +# we take the PyFFTW speeds from the second run +fft_interfaces = [FFTFilterNumpy, FFTFilterPyFFTW, FFTFilterPyFFTW] +filter_name = "disk" +filter_size = 1 / 2 +speed_norms = {} + +# load and prep the data +data_2d = edata["data"].copy() +data_2d_bg = edata["bg_data"].copy() +data_3d_prep, _ = convert_data_to_3d_array_layout(data_2d) +data_3d_bg_prep, _ = convert_data_to_3d_array_layout(data_2d_bg) + +for fft_interface in fft_interfaces: + results = {} + for n_transforms in n_transforms_list: + print(f"Running {n_transforms} transforms with " + f"{fft_interface.__name__}") + + # create batches + data_3d = np.repeat(data_3d_prep, repeats=n_transforms, axis=0) + data_3d_bg = np.repeat(data_3d_bg_prep, repeats=n_transforms, axis=0) + + assert data_3d.shape == data_3d_bg.shape == (n_transforms, + edata["data"].shape[0], + edata["data"].shape[1]) + + t0 = time.time() + holo = qpretrieve.OffAxisHologram(data=data_3d, + fft_interface=fft_interface, + subtract_mean=subtract_mean, + padding=padding) + holo.run_pipeline(filter_name=filter_name, filter_size=filter_size) + bg = qpretrieve.OffAxisHologram(data=data_3d_bg) + bg.process_like(holo) + t1 = time.time() + results[n_transforms] = t1 - t0 + + speed_norm = [timing / b_size for b_size, timing in results.items()] + # the initial PyFFTW run (incl wisdom calc is overwritten here) + speed_norms[fft_interface.__name__] = speed_norm + +# setup plot +fig, axes = plt.subplots(1, 1, figsize=(8, 5)) +ax1 = axes +width = 0.25 # the width of the bars +multiplier = 0 +x_pos = np.arange(len(n_transforms_list)) +colors = ["darkmagenta", "lightseagreen"] + +for (name, speed), color in zip(speed_norms.items(), colors): + offset = width * multiplier + ax1.bar(x_pos + offset, speed, width, label=name, + color=color, edgecolor='k') + multiplier += 1 + +ax1.set_xticks(x_pos + (width / 2), labels=n_transforms_list) +ax1.set_xlabel("Input hologram batch size") +ax1.set_ylabel("OAH processing time [Time / batch size] (s)") +ax1.legend(loc='upper right', fontsize="large") + +plt.suptitle("Batch processing time for Off-Axis Hologram\n" + "(data+bg_data)") +plt.tight_layout() +# plt.show() +plt.savefig("fft_batch_speeds.png", dpi=150) diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..6ccafc3 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1 @@ +matplotlib diff --git a/qpretrieve/data_array_layout.py b/qpretrieve/data_array_layout.py new file mode 100644 index 0000000..b614895 --- /dev/null +++ b/qpretrieve/data_array_layout.py @@ -0,0 +1,122 @@ +""" +Module that provides convenience functions for converting data between +array layouts. +""" + +import numpy as np + + +def get_allowed_array_layouts() -> list: + return [ + "rgb", + "rgba", + "3d", + "2d", + ] + + +def convert_data_to_3d_array_layout( + data: np.ndarray) -> tuple[np.ndarray, str]: + """Convert the data to the 3d array_layout + + Returns + ------- + data_out + 3d version of the data + array_layout + original array layout for future reference + + Notes + ----- + If input is either a RGB or RGBA array layout as input, the first + channel is taken as the image to process. In other words, it is assumed + that all channels contain the same information, so the first channel is + used. 3D RGB/RGBA array layouts, such as (50, 256, 256, 3), are not allowed + (yet). + + """ + if len(data.shape) == 3: + if data.shape[-1] in [1, 2, 3]: + # take the first slice (we have alpha or RGB information) + data, array_layout = _convert_rgb_to_3d(data) + elif data.shape[-1] == 4: + # take the first slice (we have alpha or RGB information) + data, array_layout = _convert_rgba_to_3d(data) + else: + # we have a 3D image stack (z, y, x) + data, array_layout = data, "3d" + elif len(data.shape) == 2: + # we have a 2D image (y, x). convert to (z, y, z) + data, array_layout = _convert_2d_to_3d(data) + else: + raise ValueError(f"data_input shape must be 2d or 3d, " + f"got shape {data.shape}.") + return data.copy(), array_layout + + +def convert_3d_data_to_array_layout( + data: np.ndarray, array_layout: str) -> np.ndarray: + """Convert the 3d data to the desired `array_layout`. + + Returns + ------- + data_out : np.ndarray + input `data` with the given `array layout` + + Notes + ----- + Currently, this function is limited to converting from 3d to other + array layouts. Perhaps if there is demand in the future, + this can be generalised for other conversions. + + """ + assert array_layout in get_allowed_array_layouts(), ( + f"`array_layout` not allowed. " + f"Allowed layouts are: {get_allowed_array_layouts()}.") + assert len(data.shape) == 3, ( + f"The data should be 3d, got {len(data.shape)=}") + data = data.copy() + if array_layout == "rgb": + data = _convert_3d_to_rgb(data) + elif array_layout == "rgba": + data = _convert_3d_to_rgba(data) + elif array_layout == "3d": + data = data + else: + data = _convert_3d_to_2d(data) + return data + + +def _convert_rgb_to_3d(data_input: np.ndarray) -> tuple[np.ndarray, str]: + data = data_input[:, :, 0] + data = data[np.newaxis, :, :] + array_layout = "rgb" + return data, array_layout + + +def _convert_rgba_to_3d(data_input: np.ndarray) -> tuple[np.ndarray, str]: + data, _ = _convert_rgb_to_3d(data_input) + array_layout = "rgba" + return data, array_layout + + +def _convert_2d_to_3d(data_input: np.ndarray) -> tuple[np.ndarray, str]: + data = data_input[np.newaxis, :, :] + array_layout = "2d" + return data, array_layout + + +def _convert_3d_to_rgb(data_input: np.ndarray) -> np.ndarray: + data = data_input[0] + data = np.dstack((data, data, data)) + return data + + +def _convert_3d_to_rgba(data_input: np.ndarray) -> np.ndarray: + data = data_input[0] + data = np.dstack((data, data, data, np.ones_like(data))) + return data + + +def _convert_3d_to_2d(data_input: np.ndarray) -> np.ndarray: + return data_input[0] diff --git a/qpretrieve/filter.py b/qpretrieve/filter.py index aa8302c..eea7451 100644 --- a/qpretrieve/filter.py +++ b/qpretrieve/filter.py @@ -3,7 +3,6 @@ import numpy as np from scipy import signal - available_filters = [ "disk", "smooth disk", @@ -15,7 +14,10 @@ @lru_cache(maxsize=32) -def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): +def get_filter_array(filter_name: str, + filter_size: float, + freq_pos: tuple[float, float], + fft_shape: tuple[int, int]) -> np.ndarray: """Create a Fourier filter for holography Parameters @@ -38,9 +40,9 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): and must be between 0 and `max(fft_shape)/2` freq_pos: tuple of floats The position of the filter in frequency coordinates as - returned by :func:`nunpy.fft.fftfreq`. + returned by :func:`numpy.fft.fftfreq`. fft_shape: tuple of int - The shape of the Fourier transformed image for which the + The shape of the Fourier transformed image (2d) for which the filter will be applied. The shape must be squared (two identical integers). @@ -55,7 +57,7 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): raise ValueError("The Fourier transformed data must have a squared " + f"shape, but the input shape is '{fft_shape}'! " + "Please pad your data properly before FFT.") - if not (0 < filter_size < max(fft_shape)/2): + if not (0 < filter_size < max(fft_shape) / 2): raise ValueError("The filter size cannot exceed more than half of " + "the Fourier space or be negative. Got a filter " + f"size of '{filter_size}' and a shape of " @@ -63,7 +65,7 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): if not (0 <= min(np.abs(freq_pos)) <= max(np.abs(freq_pos)) - < max(fft_shape)/2): + < max(fft_shape) / 2): raise ValueError("The frequency position must be within the Fourier " + f"domain. Got '{freq_pos}' and shape " + f"'{fft_shape}'!") @@ -104,8 +106,10 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): # TODO: avoid the np.roll, instead use the indices directly alpha = 0.1 rsize = int(min(fx.size, fy.size) * filter_size) * 2 - tukey_window_x = signal.tukey(rsize, alpha=alpha).reshape(-1, 1) - tukey_window_y = signal.tukey(rsize, alpha=alpha).reshape(1, -1) + tukey_window_x = signal.windows.tukey( + rsize, alpha=alpha).reshape(-1, 1) + tukey_window_y = signal.windows.tukey( + rsize, alpha=alpha).reshape(1, -1) tukey = tukey_window_x * tukey_window_y base = np.zeros(fft_shape) s1 = (np.array(fft_shape) - rsize) // 2 diff --git a/qpretrieve/fourier/__init__.py b/qpretrieve/fourier/__init__.py index 62d479c..c135f66 100644 --- a/qpretrieve/fourier/__init__.py +++ b/qpretrieve/fourier/__init__.py @@ -1,6 +1,8 @@ # flake8: noqa: F401 import warnings +from typing import Type +from .base import FFTFilter from .ff_numpy import FFTFilterNumpy try: @@ -11,7 +13,20 @@ PREFERRED_INTERFACE = None -def get_best_interface(): +def get_available_interfaces() -> list[Type[FFTFilter]]: + """Return a list of available FFT algorithms""" + interfaces = [ + FFTFilterPyFFTW, + FFTFilterNumpy, + ] + interfaces_available = [] + for interface in interfaces: + if interface is not None and interface.is_available: + interfaces_available.append(interface) + return interfaces_available + + +def get_best_interface() -> Type[FFTFilter]: """Return the fastest refocusing interface available If `pyfftw` is installed, :class:`.FFTFilterPyFFTW` diff --git a/qpretrieve/fourier/base.py b/qpretrieve/fourier/base.py index 0c430ac..7f337ba 100644 --- a/qpretrieve/fourier/base.py +++ b/qpretrieve/fourier/base.py @@ -6,6 +6,8 @@ import numpy as np from .. import filter +from ..utils import padding_3d, mean_3d +from ..data_array_layout import convert_data_to_3d_array_layout class FFTCache: @@ -35,12 +37,19 @@ def cleanup(key): class FFTFilter(ABC): - def __init__(self, data, subtract_mean=True, padding=2, copy=True): + def __init__(self, + data: np.ndarray, + subtract_mean: bool = True, + padding: int = 2, + copy: bool = True) -> None: r""" Parameters ---------- - data: 2d real-valued np.ndarray - The experimental input image + data + The experimental input real-valued image. Allowed input shapes are: + - 2d (y, x) + - 3d (z, y, x) + - rgb (y, x, 3) or rgba (y, x, 4) subtract_mean: bool If True, subtract the mean of `data` before performing the Fourier transform. This setting is recommended as it @@ -70,9 +79,16 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True): else: # convert integer-arrays to floating point arrays dtype = float + if not copy: + # numpy v2.x behaviour requires asarray with copy=False + copy = None data_ed = np.array(data, dtype=dtype, copy=copy) + # figure out what type of data we have, change it to 3d-stack + data_ed, self.orig_array_layout = convert_data_to_3d_array_layout( + data_ed) #: original data (with subtracted mean) self.origin = data_ed + # for `subtract_mean` and `padding`, we could use `np.atleast_3d` #: whether padding is enabled self.padding = padding #: whether the mean was subtracted @@ -81,14 +97,14 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True): # remove contributions of the central band # (this affects more than one pixel in the FFT # because of zero-padding) - data_ed -= data_ed.mean() + data_ed = mean_3d(data_ed) if padding: # zero padding size is next order of 2 logfact = np.log(padding * max(data_ed.shape)) - order = int(2 ** np.ceil(logfact / np.log(2))) - # this is faster than np.pad - datapad = np.zeros((order, order), dtype=dtype) - datapad[:data_ed.shape[0], :data_ed.shape[1]] = data_ed + order = np.ceil(logfact / np.log(2)) + size = int(2 ** order) + + datapad = padding_3d(data_ed, size, dtype) #: padded input data self.origin_padded = datapad data_ed = datapad @@ -107,7 +123,8 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True): self.fft_origin = fft_data else: #: frequency-shifted Fourier transform - self.fft_origin = np.fft.fftshift(self._init_fft(data_ed)) + self.fft_origin = np.fft.fftshift( + self._init_fft(data_ed), axes=(-2, -1)) # Add it to the cached FFTs if copy: FFTCache.add_item(weakref_key, data, self.fft_origin) @@ -119,13 +136,13 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True): self.fft_used = None @property - def shape(self): + def shape(self) -> tuple: """Shape of the Fourier transform data""" return self.fft_origin.shape @property @abstractmethod - def is_available(self): + def is_available(self) -> bool: """Whether this method is available given current hardware/software""" return True @@ -153,7 +170,7 @@ def _init_fft(self, data): def filter(self, filter_name: str, filter_size: float, freq_pos: (float, float), - scale_to_filter: bool | float = False): + scale_to_filter: bool | float = False) -> np.ndarray: """ Parameters ---------- @@ -175,7 +192,7 @@ def filter(self, filter_name: str, filter_size: float, and must be between 0 and `max(fft_shape)/2` freq_pos: tuple of floats The position of the filter in frequency coordinates as - returned by :func:`nunpy.fft.fftfreq`. + returned by :func:`numpy.fft.fftfreq`. scale_to_filter: bool or float Crop the image in Fourier space after applying the filter, effectively removing surplus (zero-padding) data and @@ -220,36 +237,41 @@ def filter(self, filter_name: str, filter_size: float, filter_name=filter_name, filter_size=filter_size, freq_pos=freq_pos, - fft_shape=self.fft_origin.shape) + # only take shape of a single fft + fft_shape=self.fft_origin.shape[-2:]) fft_filtered = self.fft_origin * filt_array - px = int(freq_pos[0] * self.shape[0]) - py = int(freq_pos[1] * self.shape[1]) - fft_used = np.roll(np.roll(fft_filtered, -px, axis=0), -py, axis=1) + px = int(freq_pos[0] * self.shape[-2]) + py = int(freq_pos[1] * self.shape[-1]) + fft_used = np.roll(np.roll( + fft_filtered, -px, axis=-2), -py, axis=-1) if scale_to_filter: # Determine the size of the cropping region. # We compute the "radius" of the region, so we can # crop the data left and right from the center of the # Fourier domain. - osize = fft_filtered.shape[0] # square shaped + osize = fft_filtered.shape[-2] # square shaped crad = int(np.ceil(filter_size * osize * scale_to_filter)) ccent = osize // 2 cslice = slice(ccent - crad, ccent + crad) # We now have the interesting peak already shifted to # the first entry of our array in `shifted`. - fft_used = fft_used[cslice, cslice] + fft_used = fft_used[:, cslice, cslice] + + field = self._ifft(np.fft.ifftshift(fft_used, axes=(-2, -1))) - field = self._ifft(np.fft.ifftshift(fft_used)) if self.padding: # revert padding - sx, sy = self.origin.shape + sx, sy = self.origin.shape[-2:] if scale_to_filter: sx = int(np.ceil(sx * 2 * crad / osize)) sy = int(np.ceil(sy * 2 * crad / osize)) - field = field[:sx, :sy] + + field = field[:, :sx, :sy] + if scale_to_filter: # Scale the absolute value of the field. This does not # have any influence on the phase, but on the amplitude. - field *= (2 * crad / osize)**2 + field *= (2 * crad / osize) ** 2 # Add FFT to cache # (The cache will only be cleared if this instance is deleted) FFTCache.add_item(weakref_key, self.fft_origin, diff --git a/qpretrieve/fourier/ff_numpy.py b/qpretrieve/fourier/ff_numpy.py index ad5d18b..0ac86c8 100644 --- a/qpretrieve/fourier/ff_numpy.py +++ b/qpretrieve/fourier/ff_numpy.py @@ -10,7 +10,7 @@ class FFTFilterNumpy(FFTFilter): # always available, because numpy is a dependency is_available = True - def _init_fft(self, data): + def _init_fft(self, data: np.ndarray) -> np.ndarray: """Perform initial Fourier transform of the input data Parameters @@ -23,8 +23,8 @@ def _init_fft(self, data): fft_fdata: 2d complex-valued ndarray Fourier transform `data` """ - return np.fft.fft2(data) + return np.fft.fft2(data, axes=(-2, -1)) - def _ifft(self, data): + def _ifft(self, data: np.ndarray) -> np.ndarray: """Perform inverse Fourier transform""" - return np.fft.ifft2(data) + return np.fft.ifft2(data, axes=(-2, -1)) diff --git a/qpretrieve/fourier/ff_pyfftw.py b/qpretrieve/fourier/ff_pyfftw.py index 7094f10..cd0c7c0 100644 --- a/qpretrieve/fourier/ff_pyfftw.py +++ b/qpretrieve/fourier/ff_pyfftw.py @@ -1,4 +1,5 @@ import multiprocessing as mp +import numpy as np import pyfftw @@ -8,10 +9,9 @@ class FFTFilterPyFFTW(FFTFilter): """Fourier transform using `PyFFTW `_ """ - # always available, because numpy is a dependency is_available = True - def _init_fft(self, data): + def _init_fft(self, data: np.ndarray) -> np.ndarray: """Perform initial Fourier transform of the input data Parameters @@ -27,19 +27,19 @@ def _init_fft(self, data): in_arr = pyfftw.empty_aligned(data.shape, dtype='complex128') out_arr = pyfftw.empty_aligned(data.shape, dtype='complex128') fft_obj = pyfftw.FFTW(in_arr, out_arr, - axes=(0, 1), + axes=(-2, -1), threads=mp.cpu_count()) in_arr[:] = data fft_obj() return out_arr - def _ifft(self, data): + def _ifft(self, data: np.ndarray) -> np.ndarray: """Perform inverse Fourier transform""" in_arr = pyfftw.empty_aligned(data.shape, dtype='complex128') - ou_arr = pyfftw.empty_aligned(data.shape, dtype='complex128') - fft_obj = pyfftw.FFTW(in_arr, ou_arr, axes=(0, 1), + out_arr = pyfftw.empty_aligned(data.shape, dtype='complex128') + fft_obj = pyfftw.FFTW(in_arr, out_arr, axes=(-2, -1), direction="FFTW_BACKWARD", ) in_arr[:] = data fft_obj() - return ou_arr + return out_arr diff --git a/qpretrieve/interfere/__init__.py b/qpretrieve/interfere/__init__.py index 088b3bf..9aecaa6 100644 --- a/qpretrieve/interfere/__init__.py +++ b/qpretrieve/interfere/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa: F401 +from .base import BadFFTFilterError from .if_oah import OffAxisHologram from .if_qlsi import QLSInterferogram diff --git a/qpretrieve/interfere/base.py b/qpretrieve/interfere/base.py index f1a637a..7bd4323 100644 --- a/qpretrieve/interfere/base.py +++ b/qpretrieve/interfere/base.py @@ -1,8 +1,18 @@ +import warnings from abc import ABC, abstractmethod +from typing import Type import numpy as np -from ..fourier import get_best_interface +from ..fourier import get_best_interface, get_available_interfaces +from ..fourier.base import FFTFilter +from ..data_array_layout import ( + convert_data_to_3d_array_layout, convert_3d_data_to_array_layout +) + + +class BadFFTFilterError(ValueError): + pass class BaseInterferogram(ABC): @@ -15,11 +25,25 @@ class BaseInterferogram(ABC): "invert_phase": False, } - def __init__(self, data, subtract_mean=True, padding=2, copy=True, - **pipeline_kws): + def __init__(self, data: np.ndarray, + fft_interface: str | Type[FFTFilter] = "auto", + subtract_mean=True, padding=2, copy=True, + **pipeline_kws) -> None: """ Parameters ---------- + data + The experimental input real-valued image. Allowed input shapes are: + - 2d (y, x) + - 3d (z, y, x) + - rgb (y, x, 3) or rgba (y, x, 4) + fft_interface + A Fourier transform interface. + See :func:`qpretrieve.fourier.get_available_interfaces` + to get a list of implemented interfaces. + Default is "auto", which will use + :func:`qpretrieve.fourier.get_best_interface`. This is in line + with old behaviour. See Notes for more details. subtract_mean: bool If True, remove the mean of the hologram before performing the Fourier transform. This setting is recommended as it @@ -37,16 +61,42 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True, pipeline_kws: Any additional keyword arguments for :func:`run_pipeline` as defined in :const:`default_pipeline_kws`. + + Notes + ----- + For `fft_interface`, if you do not have the relevant package installed, + then an error will be raised. For example, setting + `fft_interface=FFTFilterPyFFTW` will fail if you do not have pyfftw + installed. + """ - ff_iface = get_best_interface() - if len(data.shape) == 3: - # take the first slice (we have alpha or RGB information) - data = data[:, :, 0] + if fft_interface is None: + raise BadFFTFilterError( + "`fft_interface` is set to None. If you want qpretrieve to " + "find the best FFT interface, set it to 'auto'. " + "If you are trying to use `FFTFilterPyFFTW`, " + "you must first install the pyfftw package.") + if fft_interface == 'auto': + self.ff_iface = get_best_interface() + else: + if fft_interface in get_available_interfaces(): + self.ff_iface = fft_interface + else: + raise BadFFTFilterError( + f"User-chosen FFT Interface '{fft_interface}' is not " + f"available. The available interfaces are: " + f"{get_available_interfaces()}.\n" + f"You can use `fft_interface='auto'` to get the best " + f"available interface.") + + # figure out what type of data we have, change it to 3d-stack + data, self.orig_array_layout = convert_data_to_3d_array_layout(data) + #: qpretrieve Fourier transform interface class - self.fft = ff_iface(data=data, - subtract_mean=subtract_mean, - padding=padding, - copy=copy) + self.fft = self.ff_iface(data=data, + subtract_mean=subtract_mean, + padding=padding, + copy=copy) #: originally computed Fourier transform self.fft_origin = self.fft.fft_origin #: filtered Fourier data from last run of `run_pipeline` @@ -58,29 +108,62 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True, self._phase = None self._amplitude = None + def get_data_with_input_layout(self, data: np.ndarray | str) -> np.ndarray: + """Convert `data` to the original input array layout. + + + Parameters + ---------- + data + Either an array (np.ndarray) or name (str) of the relevant `data`. + + Returns + ------- + data_out : np.ndarray + array in the original input array layout + + Notes + ----- + If `data` is the RGBA array layout, then the alpha (A) channel will be + an array of ones. + + """ + if isinstance(data, str): + if data == "fft": + data = "fft_filtered" + warnings.warn( + "You have asked for 'fft' which is a class. " + "Returning 'fft_filtered'. " + "Alternatively you could use 'fft_origin'.") + data = getattr(self, data) + return convert_3d_data_to_array_layout(data, self.orig_array_layout) + @property - def phase(self): + def phase(self) -> np.ndarray: """Retrieved phase information""" if self._phase is None: self.run_pipeline() return self._phase @property - def amplitude(self): + def amplitude(self) -> np.ndarray: """Retrieved amplitude information""" if self._amplitude is None: self.run_pipeline() return self._amplitude @property - def field(self): + def field(self) -> np.ndarray: """Retrieved amplitude information""" if self._field is None: self.run_pipeline() return self._field - def compute_filter_size(self, filter_size, filter_size_interpretation, - sideband_freq=None): + def compute_filter_size( + self, + filter_size: float, + filter_size_interpretation: str, + sideband_freq: tuple[float, float] = None) -> float: """Compute the actual filter size in Fourier space""" if filter_size_interpretation == "frequency": # convert frequency to frequency index @@ -94,18 +177,18 @@ def compute_filter_size(self, filter_size, filter_size_interpretation, raise ValueError("For sideband distance interpretation, " "`filter_size` must be between 0 and 1; " f"got '{filter_size}'!") - fsize = np.sqrt(np.sum(np.array(sideband_freq)**2)) * filter_size + fsize = np.sqrt(np.sum(np.array(sideband_freq) ** 2)) * filter_size elif filter_size_interpretation == "frequency index": # filter size given in Fourier index (number of Fourier pixels) # The user probably does not know that we are padding in # Fourier space, so we use the unpadded size and translate it. - if filter_size <= 0 or filter_size >= self.fft.shape[0] / 2: + if filter_size <= 0 or filter_size >= self.fft.shape[-2] / 2: raise ValueError("For frequency index interpretation, " + "`filter_size` must be between 0 and " - + f"{self.fft.shape[0] / 2}, got " + + f"{self.fft.shape[-2] / 2}, got " + f"'{filter_size}'!") # convert to frequencies (compatible with fx and fy) - fsize = filter_size / self.fft.shape[0] + fsize = filter_size / self.fft.shape[-2] else: raise ValueError("Invalid value for `filter_size_interpretation`: " + f"'{filter_size_interpretation}'") diff --git a/qpretrieve/interfere/if_oah.py b/qpretrieve/interfere/if_oah.py index 73bd4bd..7da2427 100644 --- a/qpretrieve/interfere/if_oah.py +++ b/qpretrieve/interfere/if_oah.py @@ -16,7 +16,7 @@ class OffAxisHologram(BaseInterferogram): } @property - def phase(self): + def phase(self) -> np.ndarray: """Retrieved phase information""" if self._field is None: self.run_pipeline() @@ -25,14 +25,15 @@ def phase(self): return self._phase @property - def amplitude(self): + def amplitude(self) -> np.ndarray: """Retrieved amplitude information""" if self._field is None: self.run_pipeline() + if self._amplitude is None: self._amplitude = np.abs(self._field) return self._amplitude - def run_pipeline(self, **pipeline_kws): + def run_pipeline(self, **pipeline_kws) -> np.ndarray: r"""Run OAH analysis pipeline Parameters @@ -64,6 +65,8 @@ def run_pipeline(self, **pipeline_kws): sideband_freq: tuple of floats Frequency coordinates of the sideband to use. By default, a heuristic search for the sideband is done. + If you pass a 3D array, the first hologram is used to + determine the sideband frequencies. invert_phase: bool Invert the phase data. """ @@ -73,7 +76,7 @@ def run_pipeline(self, **pipeline_kws): if pipeline_kws["sideband_freq"] is None: pipeline_kws["sideband_freq"] = find_peak_cosine( - self.fft.fft_origin) + self.fft.fft_origin[0]) # convert filter_size to frequency coordinates fsize = self.compute_filter_size( @@ -100,8 +103,9 @@ def run_pipeline(self, **pipeline_kws): return self.field -def find_peak_cosine(ft_data, copy=True): - """Find the side band position of a regular off-axis hologram +def find_peak_cosine( + ft_data: np.ndarray, copy: bool = True) -> tuple[float, float]: + """Find the side band position of a 2d regular off-axis hologram The Fourier transform of a cosine function (known as the striped fringe pattern in off-axis holography) results in diff --git a/qpretrieve/interfere/if_qlsi.py b/qpretrieve/interfere/if_qlsi.py index 38f2c39..ed1ceaf 100644 --- a/qpretrieve/interfere/if_qlsi.py +++ b/qpretrieve/interfere/if_qlsi.py @@ -24,12 +24,11 @@ class QLSInterferogram(BaseInterferogram): def __init__(self, data, reference=None, *args, **kwargs): super(QLSInterferogram, self).__init__(data, *args, **kwargs) - ff_iface = get_best_interface() if reference is not None: - self.fft_ref = ff_iface(data=reference, - subtract_mean=self.fft.subtract_mean, - padding=self.fft.padding) + self.fft_ref = self.ff_iface(data=reference, + subtract_mean=self.fft.subtract_mean, + padding=self.fft.padding) else: self.fft_ref = None @@ -39,24 +38,24 @@ def __init__(self, data, reference=None, *args, **kwargs): self._field = None @property - def amplitude(self): + def amplitude(self) -> np.ndarray: if self._amplitude is None: self.run_pipeline() return self._amplitude @property - def field(self): + def field(self) -> np.ndarray: if self._field is None: - self._field = self.amplitude * np.exp(1j*2*np.pi*self.phase) + self._field = self.amplitude * np.exp(1j * 2 * np.pi * self.phase) return self._field @property - def phase(self): + def phase(self) -> np.ndarray: if self._phase is None: self.run_pipeline() return self._phase - def run_pipeline(self, **pipeline_kws): + def run_pipeline(self, **pipeline_kws) -> np.ndarray: r"""Run QLSI analysis pipeline Parameters @@ -88,6 +87,8 @@ def run_pipeline(self, **pipeline_kws): sideband_freq: tuple of floats Frequency coordinates of the sideband to use. By default, a heuristic search for the sideband is done. + If you pass a 3D array, the first hologram is used to + determine the sideband frequencies. invert_phase: bool Invert the phase data. wavelength: float @@ -120,7 +121,7 @@ def run_pipeline(self, **pipeline_kws): if pipeline_kws["sideband_freq"] is None: pipeline_kws["sideband_freq"] = find_peaks_qlsi( - self.fft.fft_origin) + self.fft.fft_origin[0]) # convert filter_size to frequency coordinates fsize = self.compute_filter_size( @@ -172,8 +173,15 @@ def run_pipeline(self, **pipeline_kws): # Obtain the phase gradients in x and y by taking the argument # of Hx and Hy. - px = unwrap_phase(np.angle(hx)) - py = unwrap_phase(np.angle(hy)) + # Every image in the 3D stack must be treated individually with + # `unwrap_phase`. If we passed the 3D stack, then skimage would + # treat this as a 3D phase-unwrapping problem, which it is not [sic!]. + # see `tests.test_qlsi.test_qlsi_unwrap_phase_2d_3d`. + px = np.zeros_like(hx, dtype=float) + py = np.zeros_like(hy, dtype=float) + for i, (_hx, _hy) in enumerate(zip(hx, hy)): + px[i] = unwrap_phase(np.angle(_hx)) + py[i] = unwrap_phase(np.angle(_hy)) # Determine the angle by which we have to rotate the gradients in # order for them to be aligned with x and y. This angle is defined @@ -183,15 +191,15 @@ def run_pipeline(self, **pipeline_kws): # Pad the gradient information so that we can rotate with cropping # (keeping the image shape the same). # TODO: Make padding dependent on rotation angle to save time? - sx, sy = px.shape - gradpad1 = np.pad(px, ((sx // 2, sx // 2), (sy // 2, sy // 2)), + sx, sy = px.shape[-2:] + gradpad1 = np.pad(px, ((0, 0), (sx // 2, sx // 2), (sy // 2, sy // 2)), mode="constant", constant_values=0) - gradpad2 = np.pad(py, ((sx // 2, sx // 2), (sy // 2, sy // 2)), + gradpad2 = np.pad(py, ((0, 0), (sx // 2, sx // 2), (sy // 2, sy // 2)), mode="constant", constant_values=0) # Perform rotation of the gradients. - rotated1 = rotate_noreshape(gradpad1, -angle) - rotated2 = rotate_noreshape(gradpad2, -angle) + rotated1 = rotate_noreshape(gradpad1, -angle, axes=(-1, -2)) + rotated2 = rotate_noreshape(gradpad2, -angle, axes=(-1, -2)) # Retrieve the wavefront by integrating the vectorial components # (integrate the total differential). This magical approach @@ -204,22 +212,24 @@ def run_pipeline(self, **pipeline_kws): copy=False) # Compute the frequencies that correspond to the frequencies of the # Fourier-transformed image. - fx = np.fft.fftfreq(rfft.shape[0]).reshape(-1, 1) - fy = np.fft.fftfreq(rfft.shape[1]).reshape(1, -1) - fxy = -2*np.pi*1j * (fx + 1j*fy) - fxy[0, 0] = 1 + fx = np.fft.fftfreq(rfft.shape[-2]).reshape(-1, 1) + fy = np.fft.fftfreq(rfft.shape[-1]).reshape(1, -1) + fxy = -2 * np.pi * 1j * (fx + 1j * fy) + fxy = np.repeat(fxy[np.newaxis, :, :], repeats=rfft.shape[0], axis=0) + fxy[:, 0, 0] = 1 # The wavefront is the real part of the inverse Fourier transform # of the filtered (divided by frequencies) data. - wfr = rfft._ifft(np.fft.ifftshift(rfft.fft_origin)/fxy).real + wfr = rfft._ifft(np.fft.ifftshift(rfft.fft_origin, + axes=(-2, -1)) / fxy).real # Rotate the wavefront back and crop it so that the FOV matches # the input data. - raw_wavefront = rotate_noreshape(wfr, - angle)[sx//2:-sx//2, sy//2:-sy//2] + raw_wavefront = rotate_noreshape( + wfr, angle, axes=(-1, -2))[:, sx // 2:-sx // 2, sy // 2:-sy // 2] # Multiply by qlsi pitch term and the scaling factor to get # the quantitative wavefront. - scaling_factor = self.fft_origin.shape[0] / wfr.shape[0] + scaling_factor = self.fft_origin.shape[-2] / wfr.shape[-2] raw_wavefront *= qlsi_pitch_term * scaling_factor self._phase = raw_wavefront / wavelength * 2 * np.pi @@ -235,7 +245,10 @@ def run_pipeline(self, **pipeline_kws): return raw_wavefront -def find_peaks_qlsi(ft_data, periodicity=4, copy=True): +def find_peaks_qlsi( + ft_data: np.ndarray, + periodicity: int = 4, + copy: bool = True) -> tuple[tuple[float, float], tuple[float, float]]: """Find the two peaks in Fourier space for the x and y gradient Parameters @@ -285,24 +298,27 @@ def find_peaks_qlsi(ft_data, periodicity=4, copy=True): ft_data[:, cy - 3:cy + 3] = 0 # circular bandpass according to periodicity - fx = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[0])).reshape(-1, 1) - fy = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[1])).reshape(1, -1) - frmask1 = np.sqrt(fx**2 + fy**2) > 1/(periodicity*.8) + fx = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[-2])).reshape(-1, 1) + fy = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[-1])).reshape(1, -1) + frmask1 = np.sqrt(fx ** 2 + fy ** 2) > 1 / (periodicity * .8) frmask2 = np.sqrt(fx ** 2 + fy ** 2) < 1 / (periodicity * 1.2) ft_data[np.logical_or(frmask1, frmask2)] = 0 # find the peak in the left part - am1 = np.argmax(np.abs(ft_data*(fy < 0))) + am1 = np.argmax(np.abs(ft_data * (fy < 0))) i1y = am1 % oy i1x = int((am1 - i1y) / oy) return fx[i1x, 0], fy[0, i1y] -def rotate_noreshape(arr, angle, mode="mirror", reshape=False): +def rotate_noreshape( + arr: np.ndarray, angle: float, axes: tuple[int, ...], + mode: str = "mirror", reshape: bool = False) -> np.ndarray: return scipy.ndimage.rotate( arr, # input angle=np.rad2deg(angle), # angle + axes=axes, reshape=reshape, # reshape order=0, # order mode=mode, # mode diff --git a/qpretrieve/utils.py b/qpretrieve/utils.py new file mode 100644 index 0000000..09307f6 --- /dev/null +++ b/qpretrieve/utils.py @@ -0,0 +1,47 @@ +import numpy as np + + +def _mean_2d(data): + """Exists for testing against mean_3d""" + data -= data.mean() + return data + + +def mean_3d(data: np.ndarray) -> np.ndarray: + """ + Subtract mean inplace from 3D array `data` for every + 2D array along z axis. + """ + # The mean array here is (1000,), so we need to add newaxes for subtraction + # (1000, 5, 5) -= (1000, 1, 1) + data -= data.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis] + return data + + +def _padding_2d(data, order, dtype): + """Exists for testing against padding_3d""" + # this is faster than np.pad + datapad = np.zeros((order, order), dtype=dtype) + # we could of course use np.atleast_3d here + datapad[:data.shape[0], :data.shape[1]] = data + return datapad + + +def padding_3d(data: np.ndarray, size: int, dtype: np.dtype) -> np.ndarray: + """Pad a 3D array in the second and third dimensions (y, x) to `size` + + Parameters + ---------- + data + 3d array. The padding will be applied to the axes (y,x) only. + size + The data will be padded to this size in the (y, x) dimensions. + dtype + data type of the padded array. + + """ + z, y, x = data.shape + # this is faster than np.pad + datapad = np.zeros((z, size, size), dtype=dtype) + datapad[:, :y, :x] = data + return datapad diff --git a/setup.py b/setup.py index 3472257..4d65ad5 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,10 @@ author = u"Paul Müller" -authors = [author] +authors = [ + author, + "Eoghan O'Connell" +] description = 'library for phase retrieval from holograms' name = 'qpretrieve' year = "2022" diff --git a/tests/conftest.py b/tests/conftest.py index 759db0d..21c9837 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,11 @@ import shutil import tempfile import time +import numpy as np -import qpretrieve +import pytest +import qpretrieve TMPDIR = tempfile.mkdtemp(prefix=time.strftime( "qpretrieve_test_%H.%M_")) @@ -22,3 +24,23 @@ def pytest_configure(config): # creating FFTW wisdom. Also, it makes the tests more reproducible # by sticking to simple numpy FFTs. qpretrieve.fourier.PREFERRED_INTERFACE = "FFTFilterNumpy" + + +@pytest.fixture(params=[64]) # default param for size +def hologram(request): + size = request.param + x = np.arange(size).reshape(-1, 1) - size / 2 + y = np.arange(size).reshape(1, -1) - size / 2 + + amp = np.linspace(.9, 1.1, size * size).reshape(size, size) + pha = np.linspace(0, 2, size * size).reshape(size, size) + + rad = x ** 2 + y ** 2 > (size / 3) ** 2 + pha[rad] = 0 + amp[rad] = 1 + + # frequencies must match pixel in Fourier space + kx = 2 * np.pi * -.3 + ky = 2 * np.pi * -.3 + image = (amp ** 2 + np.sin(kx * x + ky * y + pha) + 1) * 255 + return image diff --git a/tests/test_array_layout_convert_from_3d.py b/tests/test_array_layout_convert_from_3d.py new file mode 100644 index 0000000..7aa701d --- /dev/null +++ b/tests/test_array_layout_convert_from_3d.py @@ -0,0 +1,55 @@ +import numpy as np +import pytest + +from qpretrieve.data_array_layout import ( + convert_3d_data_to_array_layout, + _convert_3d_to_2d, _convert_3d_to_rgba, _convert_3d_to_rgb, +) + + +def test_convert_3d_data_to_2d(): + data = np.zeros(shape=(10, 256, 256)) + array_layout = "2d" + + data_new = convert_3d_data_to_array_layout(data, array_layout) + data_direct = _convert_3d_to_2d(data) # this is the internal function + + assert np.array_equal(data[0], data_new) + assert data_new.shape == data_direct.shape == (256, 256) + assert np.array_equal(data_direct, data_new) + + +def test_convert_3d_data_to_rgb(): + data = np.zeros(shape=(10, 256, 256)) + array_layout = "rgb" + + data_new = convert_3d_data_to_array_layout(data, array_layout) + data_direct = _convert_3d_to_rgb(data) # this is the internal function + + assert data_new.shape == data_direct.shape == (256, 256, 3) + assert np.array_equal(data_direct, data_new) + + +def test_convert_3d_data_to_rgba(): + data = np.zeros(shape=(10, 256, 256)) + array_layout = "rgba" + + data_new = convert_3d_data_to_array_layout(data, array_layout) + data_direct = _convert_3d_to_rgba(data) # this is the internal function + + assert data_new.shape == data_direct.shape == (256, 256, 4) + assert np.array_equal(data_direct, data_new) + + +def test_convert_3d_data_to_array_layout_bad_input(): + data = np.zeros(shape=(10, 256, 256)) + array_layout = "5d" + + with pytest.raises(AssertionError, match="`array_layout` not allowed."): + convert_3d_data_to_array_layout(data, array_layout) + + data = np.zeros(shape=(256, 256)) + array_layout = "2d" + + with pytest.raises(AssertionError, match="The data should be 3d"): + convert_3d_data_to_array_layout(data, array_layout) diff --git a/tests/test_array_layout_convert_to_3d.py b/tests/test_array_layout_convert_to_3d.py new file mode 100644 index 0000000..0c1a773 --- /dev/null +++ b/tests/test_array_layout_convert_to_3d.py @@ -0,0 +1,43 @@ +import numpy as np + +from qpretrieve.data_array_layout import convert_data_to_3d_array_layout + + +def test_check_data_input_2d(): + data = np.zeros(shape=(256, 256)) + + data_new, orig_array_layout = convert_data_to_3d_array_layout(data) + + assert data_new.shape == (1, 256, 256) + assert np.array_equal(data_new[0], data) + assert orig_array_layout == "2d" + + +def test_check_data_input_3d_image_stack(): + data = np.zeros(shape=(50, 256, 256)) + + data_new, orig_array_layout = convert_data_to_3d_array_layout(data) + + assert data_new.shape == (50, 256, 256) + assert np.array_equal(data_new, data) + assert orig_array_layout == "3d" + + +def test_check_data_input_3d_rgb(): + data = np.zeros(shape=(256, 256, 3)) + + data_new, orig_array_layout = convert_data_to_3d_array_layout(data) + + assert data_new.shape == (1, 256, 256) + assert np.array_equal(data_new[0], data[:, :, 0]) + assert orig_array_layout == "rgb" + + +def test_check_data_input_3d_rgba(): + data = np.zeros(shape=(256, 256, 4)) + + data_new, orig_array_layout = convert_data_to_3d_array_layout(data) + + assert data_new.shape == (1, 256, 256) + assert np.array_equal(data_new[0], data[:, :, 0]) + assert orig_array_layout == "rgba" diff --git a/tests/test_fourier_base.py b/tests/test_fourier_base.py index c364ebb..d73d04c 100644 --- a/tests/test_fourier_base.py +++ b/tests/test_fourier_base.py @@ -1,3 +1,4 @@ +import pytest import copy import pathlib @@ -16,7 +17,7 @@ def test_scale_sanity_check(): # create a 2D gaussian test image x = np.linspace(-100, 100, 100) xx, yy = np.meshgrid(x, -x, indexing="ij") - gauss = np.exp(-(xx**2 + yy**2) / 625) + gauss = np.exp(-(xx ** 2 + yy ** 2) / 625) ft = fourier.FFTFilterNumpy(gauss, subtract_mean=False) @@ -106,7 +107,7 @@ def test_scale_to_filter_oah(): ifr.run_pipeline() phase = unwrap_phase(ifh.phase - ifr.phase) - assert phase.shape == (200, 210) + assert phase.shape == (1, 200, 210) assert np.allclose(phase.mean(), 1.0840394954441188, atol=1e-5) # Rescaled pipeline @@ -115,10 +116,24 @@ def test_scale_to_filter_oah(): ifh.run_pipeline(**pipeline_kws_scale) ifr.run_pipeline(**pipeline_kws_scale) phase_scaled = unwrap_phase(ifh.phase - ifr.phase) - assert phase_scaled.shape == (33, 34) + assert phase_scaled.shape == (1, 33, 34) assert np.allclose(phase_scaled.mean(), 1.0469570087033453, atol=1e-5) +def test_bad_fft_interface_input(): + """Fails because inputting PyFFTW without installing defaults to None""" + data = np.load(data_path / "hologram_cell.npz") + image = data["data"] + + with pytest.raises( + interfere.BadFFTFilterError, + match="`fft_interface` is set to None. If you want qpretrieve to " + "find the best FFT interface, set it to 'auto'. " + "If you are trying to use `FFTFilterPyFFTW`, " + "you must first install the pyfftw package."): + interfere.OffAxisHologram(image, fft_interface=None) + + def test_scale_to_filter_qlsi(): with h5py.File(data_path / "qlsi_paa_bead.h5") as h5: image = h5["0"][:] @@ -136,12 +151,23 @@ def test_scale_to_filter_qlsi(): } ifh = interfere.QLSInterferogram(image, **pipeline_kws) - ifh.run_pipeline() + raw_wavefront = ifh.run_pipeline() + assert raw_wavefront.shape == (1, 720, 720) + assert ifh.phase.shape == (1, 720, 720) + assert ifh.amplitude.shape == (1, 720, 720) + assert ifh.field.shape == (1, 720, 720) ifr = interfere.QLSInterferogram(refer, **pipeline_kws) ifr.run_pipeline() + assert ifr.phase.shape == (1, 720, 720) + assert ifr.amplitude.shape == (1, 720, 720) + assert ifr.field.shape == (1, 720, 720) + + ifh_phase = ifh.phase[0] + ifr_phase = ifr.phase[0] + + phase = unwrap_phase(ifh_phase - ifr_phase) - phase = unwrap_phase(ifh.phase - ifr.phase) assert phase.shape == (720, 720) assert np.allclose(phase.mean(), 0.12434563269684816, atol=1e-6) @@ -152,6 +178,93 @@ def test_scale_to_filter_qlsi(): ifr.run_pipeline(**pipeline_kws_scale) phase_scaled = unwrap_phase(ifh.phase - ifr.phase) - assert phase_scaled.shape == (126, 126) + assert phase_scaled.shape == (1, 126, 126) assert np.allclose(phase_scaled.mean(), 0.1257080793074251, atol=1e-6) + + +def test_fft_dimensionality_consistency(): + """Compare using fft algorithms on 2d and 3d data.""" + image_3d = np.arange(1000).reshape(10, 10, 10) + image_2d = image_3d[0].copy() + + # fft with shift + fft_3d = np.fft.fftshift(np.fft.fft2(image_3d, axes=(-2, -1)), + axes=(-2, -1)) + fft_2d = np.fft.fftshift(np.fft.fft2(image_2d)) # old qpretrieve + assert fft_3d.shape == (10, 10, 10) + assert fft_2d.shape == (10, 10) + assert np.allclose(fft_3d[0], fft_2d, rtol=0, atol=1e-8) + + # ifftshift + fft_3d_shifted = np.fft.ifftshift(fft_3d, axes=(-2, -1)) + fft_2d_shifted = np.fft.ifftshift(fft_2d) # old qpretrieve + assert fft_3d_shifted.shape == (10, 10, 10) + assert fft_2d_shifted.shape == (10, 10) + assert np.allclose(fft_3d_shifted[0], fft_2d_shifted, rtol=0, atol=1e-8) + + # ifft + ifft_3d_shifted = np.fft.ifft2(fft_3d_shifted, axes=(-2, -1)) + ifft_2d_shifted = np.fft.ifft2(fft_2d_shifted) # old qpretrieve + assert ifft_3d_shifted.shape == (10, 10, 10) + assert ifft_2d_shifted.shape == (10, 10) + assert np.allclose(ifft_3d_shifted[0], ifft_2d_shifted, rtol=0, atol=1e-8) + + assert np.allclose(ifft_3d_shifted.real, image_3d, rtol=0, atol=1e-8) + assert np.allclose(ifft_2d_shifted.real, image_2d, rtol=0, atol=1e-8) + + +def test_fft_comparison_FFTFilter(): + image = np.arange(1000).reshape(10, 10, 10) + ff_np = fourier.FFTFilterNumpy(image, subtract_mean=False, padding=False) + ff_tw = fourier.FFTFilterPyFFTW(image, subtract_mean=False, padding=False) + assert ff_np.fft_origin.shape == ff_tw.fft_origin.shape == (10, 10, 10) + + assert np.allclose(ff_np.fft_origin, ff_tw.fft_origin, rtol=0, atol=1e-8) + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift(ff_np.fft_origin, axes=(-2, -1))).real, + np.fft.ifft2(np.fft.ifftshift(ff_tw.fft_origin, axes=(-2, -1))).real, + rtol=0, + atol=1e-8 + ) + + +def test_fft_comparison_data_input_fmt(): + image = np.arange(1000).reshape(10, 10, 10) + FFTFilters = [fourier.FFTFilterNumpy, fourier.FFTFilterPyFFTW] + + for fftfilt in FFTFilters: + # 3d input + ff_3d = fftfilt(image, subtract_mean=False, padding=False) + # 2d input + ff_arr_2d = np.zeros_like(ff_3d.fft_origin) + for i, img in enumerate(image): + ff_2d = fftfilt(img, subtract_mean=False, padding=False) + ff_arr_2d[i] = ff_2d.fft_origin + + # ffts are the same + assert np.allclose(ff_2d.fft_origin, + ff_3d.fft_origin[i], + rtol=0, atol=1e-8) + # iffts are the same + assert np.allclose(np.fft.ifft2(ff_2d.fft_origin).real, + np.fft.ifft2(ff_3d.fft_origin[i]).real, + rtol=0, atol=1e-8) + # shifted iffts are the same, if you use arg axes + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift( + ff_2d.fft_origin, axes=(-2, -1))).real, + np.fft.ifft2(np.fft.ifftshift( + ff_3d.fft_origin[i], axes=(-2, -1))).real, + rtol=0, atol=1e-8) + # shifted 2d ifft is the same as the 2d img + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift( + ff_2d.fft_origin, axes=(-2, -1))).real, + img, rtol=0, atol=1e-8) + + # shifted 3d ifft is the same as the 3d img + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift(ff_3d.fft_origin, + axes=(-2, -1))).real, + image, rtol=0, atol=1e-8) diff --git a/tests/test_fourier_numpy.py b/tests/test_fourier_numpy.py index c95cc69..af161a8 100644 --- a/tests/test_fourier_numpy.py +++ b/tests/test_fourier_numpy.py @@ -3,12 +3,53 @@ from qpretrieve import fourier -def test_fft_correct(): +def test_fft_correct_input_2d(): image = np.arange(100).reshape(10, 10) ff = fourier.FFTFilterNumpy(image, subtract_mean=False, padding=False) + assert ff.fft_origin.shape == (1, 10, 10) assert np.allclose( - np.fft.ifft2(np.fft.ifftshift(ff.fft_origin)).real, + np.fft.ifft2(np.fft.ifftshift(ff.fft_origin, axes=(-2, -1))).real, image, rtol=0, atol=1e-8 ) + + +def test_fft_correct_input_3d(): + image = np.arange(1000).reshape(10, 10, 10) + ff = fourier.FFTFilterNumpy(image, subtract_mean=False, padding=False) + assert ff.fft_origin.shape == (10, 10, 10) + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift(ff.fft_origin, axes=(-2, -1))).real, + image, + rtol=0, + atol=1e-8 + ) + + +def test_fft_correct_input_rgb(): + image = np.arange(300).reshape(10, 10, 3) + ff = fourier.FFTFilterNumpy(image, subtract_mean=False, padding=False) + # does the same as `data_input._convert_rgb_to_3d` + expected_image = image[:, :, 0][np.newaxis, :, :].copy() + assert ff.fft_origin.shape == (1, 10, 10) + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift(ff.fft_origin, axes=(-2, -1))).real, + expected_image, + rtol=0, + atol=1e-8 + ) + + +def test_fft_correct_input_rgba(): + image = np.arange(400).reshape(10, 10, 4) + ff = fourier.FFTFilterNumpy(image, subtract_mean=False, padding=False) + # does the same as `data_input._convert_rgb_to_3d` + expected_image = image[:, :, 0][np.newaxis, :, :].copy() + assert ff.fft_origin.shape == (1, 10, 10) + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift(ff.fft_origin, axes=(-2, -1))).real, + expected_image, + rtol=0, + atol=1e-8 + ) diff --git a/tests/test_fourier_pyfftw.py b/tests/test_fourier_pyfftw.py new file mode 100644 index 0000000..9698dfa --- /dev/null +++ b/tests/test_fourier_pyfftw.py @@ -0,0 +1,55 @@ +import numpy as np + +from qpretrieve import fourier + + +def test_fft_correct_input_2d(): + image = np.arange(100).reshape(10, 10) + ff = fourier.FFTFilterPyFFTW(image, subtract_mean=False, padding=False) + assert ff.fft_origin.shape == (1, 10, 10) + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift(ff.fft_origin, axes=(-2, -1))).real, + image, + rtol=0, + atol=1e-8 + ) + + +def test_fft_correct_input_3d(): + image = np.arange(1000).reshape(10, 10, 10) + ff = fourier.FFTFilterPyFFTW(image, subtract_mean=False, padding=False) + assert ff.fft_origin.shape == (10, 10, 10) + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift(ff.fft_origin, axes=(-2, -1))).real, + image, + rtol=0, + atol=1e-8 + ) + + +def test_fft_correct_input_rgb(): + image = np.arange(300).reshape(10, 10, 3) + ff = fourier.FFTFilterPyFFTW(image, subtract_mean=False, padding=False) + # does the same as `data_input._convert_rgb_to_3d` + expected_image = image[:, :, 0][np.newaxis, :, :].copy() + assert ff.fft_origin.shape == (1, 10, 10) + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift(ff.fft_origin, axes=(-2, -1))).real, + expected_image, + rtol=0, + atol=1e-8 + ) + + +def test_fft_correct_input_rgba(): + image = np.arange(400).reshape(10, 10, 4) + ff = fourier.FFTFilterPyFFTW(image, subtract_mean=False, padding=False) + # does the same as `data_input._convert_rgb_to_3d` + expected_image = image[:, :, 0][np.newaxis, :, :].copy() + assert ff.fft_origin.shape == (1, 10, 10) + assert np.allclose( + np.fft.ifft2(np.fft.ifftshift(ff.fft_origin, axes=(-2, -1))).real, + expected_image, + rtol=0, + atol=1e-8 + ) diff --git a/tests/test_interfere_base.py b/tests/test_interfere_base.py new file mode 100644 index 0000000..1428e04 --- /dev/null +++ b/tests/test_interfere_base.py @@ -0,0 +1,148 @@ +import pathlib +import numpy as np +import pytest + +import qpretrieve + +data_path = pathlib.Path(__file__).parent / "data" + + +def test_interfere_base_best_interface(): + edata = np.load(data_path / "hologram_cell.npz") + + holo = qpretrieve.OffAxisHologram(data=edata["data"]) + assert holo.ff_iface.is_available + assert issubclass(holo.ff_iface, + qpretrieve.fourier.base.FFTFilter) + assert issubclass(holo.ff_iface, + qpretrieve.fourier.ff_numpy.FFTFilterNumpy) + + +def test_interfere_base_choose_interface(): + edata = np.load(data_path / "hologram_cell.npz") + + for InterferCls in [qpretrieve.OffAxisHologram, + qpretrieve.QLSInterferogram]: + interfer_inst = InterferCls( + data=edata["data"], + fft_interface=qpretrieve.fourier.FFTFilterNumpy) + assert interfer_inst.ff_iface.is_available + assert issubclass(interfer_inst.ff_iface, + qpretrieve.fourier.base.FFTFilter) + assert issubclass(interfer_inst.ff_iface, + qpretrieve.fourier.ff_numpy.FFTFilterNumpy) + + +def test_interfere_base_bad_interface(): + edata = np.load(data_path / "hologram_cell.npz") + bad_name = "MyReallyCoolFFTInterface" + + with pytest.raises( + qpretrieve.interfere.BadFFTFilterError, + match=f"User-chosen FFT Interface '{bad_name}' is not available."): + _ = qpretrieve.OffAxisHologram( + data=edata["data"], + fft_interface=bad_name) + + +def test_interfere_base_orig_array_layout(): + edata = np.load(data_path / "hologram_cell.npz") + + holo = qpretrieve.OffAxisHologram(data=edata["data"]) + assert holo.orig_array_layout is not None + assert holo.orig_array_layout == "2d" + + +def test_interfere_base_get_data_with_input_layout(): + edata = np.load(data_path / "hologram_cell.npz") + orig_shape = (200, 210) + assert edata["data"].shape == orig_shape + + holo = qpretrieve.OffAxisHologram(data=edata["data"]) + assert holo.field.shape == (1, 200, 210) + + field_orig1 = holo.get_data_with_input_layout(data=holo.field) + field_orig2 = holo.get_data_with_input_layout(data="field") + + assert field_orig1.shape == field_orig2.shape == orig_shape + + +def test_interfere_base_get_data_with_input_layout_fft_warning(): + edata = np.load(data_path / "hologram_cell.npz") + orig_shape = (200, 210) + assert edata["data"].shape == orig_shape + + holo = qpretrieve.OffAxisHologram(data=edata["data"]) + assert holo.field.shape == (1, 200, 210) + + fft_orig1 = holo.get_data_with_input_layout(data="fft_filtered") + fft_orig2 = holo.get_data_with_input_layout(data="fft") + assert fft_orig1.shape == fft_orig2.shape + + +def test_get_data_with_input_layout_2d(hologram): + """The original data format should be returned correctly""" + data_2d = hologram + expected_output_shape = (1, data_2d.shape[-2], data_2d.shape[-1]) + + # 2d data format + oah = qpretrieve.OffAxisHologram(data_2d, padding=False, + subtract_mean=False) + res = oah.run_pipeline() + assert res.shape == expected_output_shape + + data_attrs = [oah.field, oah.fft_origin, oah.fft_filtered, + oah.amplitude, oah.phase, + "field", "fft_origin", "fft_filtered", + "amplitude", "phase"] + for data_attr in data_attrs: + if not isinstance(data_attr, str): + assert data_attr.shape == expected_output_shape + # original shape was 2d + orig_data = oah.get_data_with_input_layout(data_attr) + assert orig_data.shape == data_2d.shape + + +def test_get_data_with_input_layout_rgb(hologram): + """The original data format should be returned correctly""" + data_rgb = np.stack([hologram, hologram, hologram], axis=-1) + expected_output_shape = (1, hologram.shape[-2], hologram.shape[-1]) + + # 2d data format + oah = qpretrieve.OffAxisHologram(data_rgb, padding=False, + subtract_mean=False) + _ = oah.run_pipeline() + + data_attrs = [oah.field, oah.fft_origin, oah.fft_filtered, + oah.amplitude, oah.phase, + "field", "fft_origin", "fft_filtered", + "amplitude", "phase"] + for data_attr in data_attrs: + if not isinstance(data_attr, str): + assert data_attr.shape == expected_output_shape + # original shape was 2d + assert oah.get_data_with_input_layout( + data_attr).shape == data_rgb.shape + + +def test_get_data_with_input_layout_rgba(hologram): + """The original data format should be returned correctly""" + data_rgba = np.stack([hologram, hologram, hologram, + np.zeros_like(hologram)], axis=-1) + expected_output_shape = (1, hologram.shape[-2], hologram.shape[-1]) + + # 2d data format + oah = qpretrieve.OffAxisHologram(data_rgba, padding=False, + subtract_mean=False) + _ = oah.run_pipeline() + + data_attrs = [oah.field, oah.fft_origin, oah.fft_filtered, + oah.amplitude, oah.phase, + "field", "fft_origin", "fft_filtered", + "amplitude", "phase"] + for data_attr in data_attrs: + if not isinstance(data_attr, str): + assert data_attr.shape == expected_output_shape + # original shape was 2d + assert oah.get_data_with_input_layout( + data_attr).shape == data_rgba.shape diff --git a/tests/test_oah.py b/tests/test_oah.py new file mode 100644 index 0000000..4704978 --- /dev/null +++ b/tests/test_oah.py @@ -0,0 +1,323 @@ +import numpy as np +import pytest +import pathlib + +import qpretrieve +from qpretrieve.interfere import if_oah +from qpretrieve.fourier import FFTFilterNumpy, FFTFilterPyFFTW +from qpretrieve.data_array_layout import ( + convert_data_to_3d_array_layout, + _convert_2d_to_3d, _convert_3d_to_rgb, _convert_3d_to_rgba, +) + +data_path = pathlib.Path(__file__).parent / "data" + + +def test_find_sideband(): + size = 40 + ft_data = np.zeros((size, size)) + fx = np.fft.fftshift(np.fft.fftfreq(size)) + ft_data[2, 3] = 1 + ft_data[-3, -2] = 1 + + sb1 = if_oah.find_peak_cosine(ft_data=ft_data) + assert np.allclose(sb1, (fx[2], fx[3])) + + +def test_fourier2dpad(): + y, x = 100, 120 + data = np.zeros((y, x)) + fft1 = qpretrieve.fourier.FFTFilterNumpy(data) + assert fft1.shape == (1, 256, 256) + + fft2 = qpretrieve.fourier.FFTFilterNumpy(data, padding=False) + assert fft2.shape == (1, y, x) + + +def test_get_field_error_bad_filter_size(hologram): + data = hologram + + holo = qpretrieve.OffAxisHologram(data) + with pytest.raises(ValueError, match="must be between 0 and 1"): + holo.run_pipeline(filter_size=2) + + +def test_get_field_error_bad_filter_size_interpretation_frequency_index( + hologram): + data = hologram + holo = qpretrieve.OffAxisHologram(data) + + with pytest.raises(ValueError, + match=r"must be between 0 and 64"): + holo.run_pipeline(filter_size_interpretation="frequency index", + filter_size=64) + + +def test_get_field_error_invalid_interpretation(hologram): + data = hologram + holo = qpretrieve.OffAxisHologram(data) + + with pytest.raises(ValueError, + match="Invalid value for `filter_size_interpretation`"): + holo.run_pipeline(filter_size_interpretation="blequency") + + +def test_get_field_filter_names(hologram): + data_2d = hologram + data_3d, _ = convert_data_to_3d_array_layout(data_2d) + data_3d = np.repeat(data_3d, repeats=10, axis=0) + kwargs = dict(sideband=+1, + filter_size=1 / 3) + bad_filter_name = "sepia" + + for data in (data_2d, data_3d): + holo = qpretrieve.OffAxisHologram(data) + + r_disk = holo.run_pipeline(filter_name="disk", **kwargs) + r_smdisk = holo.run_pipeline(filter_name="smooth disk", **kwargs) + r_gauss = holo.run_pipeline(filter_name="gauss", **kwargs) + r_square = holo.run_pipeline(filter_name="square", **kwargs) + r_smsquare = holo.run_pipeline(filter_name="smooth square", **kwargs) + r_tukey = holo.run_pipeline(filter_name="tukey", **kwargs) + + for i in range(len(r_disk)): + assert np.allclose( + r_disk[i, 32, 32], 97.307780444912936 - 76.397860381241372j) + assert np.allclose( + r_smdisk[i, 32, 32], 108.36438759594623 - 67.1806221692573j) + assert np.allclose( + r_gauss[i, 32, 32], 108.2914187451138 - 67.1823527237741j) + assert np.allclose( + r_square[i, 32, 32], 102.3285348843612 - 74.139058665601155j) + assert np.allclose( + r_smsquare[i, 32, 32], 108.36651862466393 - 67.17988960794392j) + assert np.allclose( + r_tukey[i, 32, 32], 113.4826495540899 - 59.546232775481869j) + + with pytest.raises(ValueError, match=f"Unknown filter: {bad_filter_name}"): + holo.run_pipeline(filter_name=bad_filter_name, **kwargs) + + +@pytest.mark.parametrize("hologram", [62, 63, 64], indirect=True) +def test_get_field_interpretation_fourier_index(hologram): + """Filter size in Fourier space using Fourier index new in 0.7.0""" + data = hologram + shape_expected = (1, hologram.shape[-2], hologram.shape[-1]) + holo = qpretrieve.OffAxisHologram(data) + + ft_data = holo.fft_origin + holo.run_pipeline() + fsx, fsy = holo.pipeline_kws["sideband_freq"] + + kwargs1 = dict(filter_name="disk", + filter_size=1 / 3, + filter_size_interpretation="sideband distance") + res1 = holo.run_pipeline(**kwargs1) + + filter_size_fi = np.sqrt(fsx ** 2 + fsy ** 2) / 3 * ft_data.shape[-2] + kwargs2 = dict(filter_name="disk", + filter_size=filter_size_fi, + filter_size_interpretation="frequency index", + ) + res2 = holo.run_pipeline(**kwargs2) + + assert res1.shape == shape_expected + assert res2.shape == shape_expected + assert np.all(res1 == res2) + + +@pytest.mark.parametrize("hologram", [62, 63, 64], indirect=["hologram"]) +def test_get_field_interpretation_fourier_index_control(hologram): + """Filter size in Fourier space using Fourier index new in 0.7.0""" + data = hologram + holo = qpretrieve.OffAxisHologram(data) + + ft_data = holo.fft_origin + holo.run_pipeline() + fsx, fsy = holo.pipeline_kws["sideband_freq"] + + evil_factor = 1.1 + + kwargs1 = dict(filter_name="disk", + filter_size=1 / 3 * evil_factor, + filter_size_interpretation="sideband distance" + ) + res1 = holo.run_pipeline(**kwargs1) + + filter_size_fi = np.sqrt(fsx ** 2 + fsy ** 2) / 3 * ft_data.shape[-2] + kwargs2 = dict(filter_name="disk", + filter_size=filter_size_fi, + filter_size_interpretation="frequency index", + ) + res2 = holo.run_pipeline(**kwargs2) + assert not np.all(res1 == res2) + + +@pytest.mark.parametrize("hologram", [62, 63, 64, 134, 135], + indirect=["hologram"]) +@pytest.mark.parametrize("filter_size", [17, 17.01]) +def test_get_field_interpretation_fourier_index_mask_1(hologram, filter_size): + """Make sure filter size in Fourier space pixels is correct""" + data = hologram + holo = qpretrieve.OffAxisHologram(data) + + kwargs2 = dict(filter_name="disk", + filter_size=filter_size, + filter_size_interpretation="frequency index", + ) + holo.run_pipeline(**kwargs2) + mask = holo.fft_filtered.real != 0 + + # We get 17*2+1, because we measure from the center of Fourier + # space and a pixel is included if its center is withing the + # perimeter of the disk. + assert np.sum(np.sum(mask, axis=-2) != 0) == 17 * 2 + 1 + + +@pytest.mark.parametrize("hologram", [62, 63, 64, 134, 135], + indirect=["hologram"]) +def test_get_field_interpretation_fourier_index_mask_2(hologram): + """Filter size in Fourier space using Fourier index new in 0.7.0""" + data = hologram + holo = qpretrieve.OffAxisHologram(data) + + kwargs2 = dict(filter_name="disk", + filter_size=16.99, + filter_size_interpretation="frequency index" + ) + holo.run_pipeline(**kwargs2) + mask = holo.fft_filtered.real != 0 + + # We get two points less than in the previous test, because we + # loose on each side of the spectrum. + assert np.sum(np.sum(mask, axis=-2) != 0) == 17 * 2 - 1 + + +def test_get_field_int_copy(hologram): + data = hologram + data = np.array(data, dtype=int) + + kwargs = dict(filter_size=1 / 3) + + holo1 = qpretrieve.OffAxisHologram(data, copy=False) + res1 = holo1.run_pipeline(**kwargs) + + holo2 = qpretrieve.OffAxisHologram(data, copy=True) + res2 = holo2.run_pipeline(**kwargs) + + holo3 = qpretrieve.OffAxisHologram(data.astype(float), copy=True) + res3 = holo3.run_pipeline(**kwargs) + + assert np.all(res1 == res2) + assert np.all(res1 == res3) + + +def test_get_field_sideband(hologram): + data = hologram + holo = qpretrieve.OffAxisHologram(data) + holo.run_pipeline() + invert_phase = holo.pipeline_kws["invert_phase"] + + kwargs = dict(filter_name="disk", + filter_size=1 / 3) + + res1 = holo.run_pipeline(invert_phase=False, **kwargs) + res2 = holo.run_pipeline(invert_phase=invert_phase, **kwargs) + assert np.all(res1 == res2) + + +def test_get_field_three_axes(hologram): + data1 = hologram + # create a copy with empty entry in third axis + data2 = np.zeros((data1.shape[0], data1.shape[1], 3)) + data2[:, :, 0] = data1 + # both will be output as (z,y,x) shaped image stacks + shape_expected = (1, hologram.shape[-2], hologram.shape[-1]) + + holo1 = qpretrieve.OffAxisHologram(data1) + holo2 = qpretrieve.OffAxisHologram(data2) + + kwargs = dict(filter_name="disk", + filter_size=1 / 3) + res1 = holo1.run_pipeline(**kwargs) + res2 = holo2.run_pipeline(**kwargs) + + assert res1.shape == shape_expected + assert res2.shape == shape_expected + assert np.all(res1 == res2) + + +def test_get_field_compare_FFTFilters(hologram): + data1 = hologram + kwargs = dict(filter_name="disk", filter_size=1 / 3) + padding = False + shape_expected = (1, hologram.shape[-2], hologram.shape[-1]) + + holo1 = qpretrieve.OffAxisHologram(data1, + fft_interface=FFTFilterNumpy, + padding=padding) + res1 = holo1.run_pipeline(**kwargs) + assert res1.shape == shape_expected + + holo2 = qpretrieve.OffAxisHologram(data1, + fft_interface=FFTFilterPyFFTW, + padding=padding) + res2 = holo2.run_pipeline(**kwargs) + assert res2.shape == shape_expected + + # not exactly the same, but roughly equal to 1e-5 + assert np.allclose(holo1.fft.fft_used, holo2.fft.fft_used) + assert np.allclose(res1, res2) + + +def test_field_format_consistency(hologram): + """The data format returned should be (z,y,x)""" + data_2d = hologram.copy() + shape_expected = (1, hologram.shape[-2], hologram.shape[-1]) + + # 2d data format + holo_2d = qpretrieve.OffAxisHologram(data_2d) + res_2d = holo_2d.run_pipeline() + assert res_2d.shape == shape_expected + + # 3d data format + data_3d, _ = _convert_2d_to_3d(data_2d) + holo_3d = qpretrieve.OffAxisHologram(data_3d) + res_3d = holo_3d.run_pipeline() + assert res_3d.shape == shape_expected + + # rgb data format + data_rgb = _convert_3d_to_rgb(data_3d) + holo_rgb = qpretrieve.OffAxisHologram(data_rgb) + res_rgb = holo_rgb.run_pipeline() + assert res_rgb.shape == shape_expected + + # rgba data format + data_rgba = _convert_3d_to_rgba(data_3d) + holo_rgba = qpretrieve.OffAxisHologram(data_rgba) + res_rgba = holo_rgba.run_pipeline() + assert res_rgba.shape == shape_expected + + assert np.all(res_2d == res_3d) + assert np.all(res_2d == res_rgb) + assert np.all(res_2d == res_rgba) + + +def test_oah_2d_vs_3d_processing(): + edata_2d = np.load(data_path / "hologram_cell.npz")["data"] + edata_3d, _ = convert_data_to_3d_array_layout(edata_2d) + edata_3d = np.repeat(edata_3d, repeats=5, axis=0) + + # 3d + holo_3d = qpretrieve.OffAxisHologram(data=edata_3d, padding=True) + fields_3d = holo_3d.run_pipeline() + + # 2d + fields_2d = np.zeros_like(fields_3d) + for i in range(fields_2d.shape[0]): + holo = qpretrieve.OffAxisHologram(data=edata_2d, padding=True) + holo.run_pipeline() + fields_2d[i] = holo.field[0] # there is only 1 per input + + assert fields_2d.shape == fields_3d.shape + assert np.array_equal(fields_2d, fields_3d) diff --git a/tests/test_oah_from_qpimage.py b/tests/test_oah_from_qpimage.py deleted file mode 100644 index 838b180..0000000 --- a/tests/test_oah_from_qpimage.py +++ /dev/null @@ -1,248 +0,0 @@ -"""These are tests from qpimage""" -import numpy as np -import pytest - -import qpretrieve -from qpretrieve.interfere import if_oah - - -def hologram(size=64): - x = np.arange(size).reshape(-1, 1) - size / 2 - y = np.arange(size).reshape(1, -1) - size / 2 - - amp = np.linspace(.9, 1.1, size * size).reshape(size, size) - pha = np.linspace(0, 2, size * size).reshape(size, size) - - rad = x**2 + y**2 > (size / 3)**2 - pha[rad] = 0 - amp[rad] = 1 - - # frequencies must match pixel in Fourier space - kx = 2 * np.pi * -.3 - ky = 2 * np.pi * -.3 - image = (amp**2 + np.sin(kx * x + ky * y + pha) + 1) * 255 - return image - - -def test_find_sideband(): - size = 40 - ft_data = np.zeros((size, size)) - fx = np.fft.fftshift(np.fft.fftfreq(size)) - ft_data[2, 3] = 1 - ft_data[-3, -2] = 1 - - sb1 = if_oah.find_peak_cosine(ft_data=ft_data) - assert np.allclose(sb1, (fx[2], fx[3])) - - -def test_fourier2dpad(): - data = np.zeros((100, 120)) - fft1 = qpretrieve.fourier.FFTFilterNumpy(data) - assert fft1.shape == (256, 256) - - fft2 = qpretrieve.fourier.FFTFilterNumpy(data, padding=False) - assert fft2.shape == data.shape - - -def test_get_field_error_bad_filter_size(): - data = hologram() - - holo = qpretrieve.OffAxisHologram(data) - with pytest.raises(ValueError, match="must be between 0 and 1"): - holo.run_pipeline(filter_size=2) - - -def test_get_field_error_bad_filter_size_interpretation_frequency_index(): - data = hologram(size=64) - holo = qpretrieve.OffAxisHologram(data) - - with pytest.raises(ValueError, - match=r"must be between 0 and 64"): - holo.run_pipeline(filter_size_interpretation="frequency index", - filter_size=64) - - -def test_get_field_error_invalid_interpretation(): - data = hologram() - holo = qpretrieve.OffAxisHologram(data) - - with pytest.raises(ValueError, - match="Invalid value for `filter_size_interpretation`"): - holo.run_pipeline(filter_size_interpretation="blequency") - - -def test_get_field_filter_names(): - data = hologram() - holo = qpretrieve.OffAxisHologram(data) - - kwargs = dict(sideband=+1, - filter_size=1 / 3) - - r_disk = holo.run_pipeline(filter_name="disk", **kwargs) - assert np.allclose( - r_disk[32, 32], 97.307780444912936 - 76.397860381241372j) - - r_smooth_disk = holo.run_pipeline(filter_name="smooth disk", **kwargs) - assert np.allclose(r_smooth_disk[32, 32], - 108.36438759594623-67.1806221692573j) - - r_gauss = holo.run_pipeline(filter_name="gauss", **kwargs) - assert np.allclose(r_gauss[32, 32], - 108.2914187451138-67.1823527237741j) - - r_square = holo.run_pipeline(filter_name="square", **kwargs) - assert np.allclose( - r_square[32, 32], 102.3285348843612 - 74.139058665601155j) - - r_smsquare = holo.run_pipeline(filter_name="smooth square", **kwargs) - assert np.allclose( - r_smsquare[32, 32], 108.36651862466393-67.17988960794392j) - - r_tukey = holo.run_pipeline(filter_name="tukey", **kwargs) - assert np.allclose( - r_tukey[32, 32], 113.4826495540899 - 59.546232775481869j) - - try: - holo.run_pipeline(filter_name="unknown", **kwargs) - except ValueError: - pass - else: - assert False, "unknown filter accepted" - - -@pytest.mark.parametrize("size", [62, 63, 64]) -def test_get_field_interpretation_fourier_index(size): - """Filter size in Fourier space using Fourier index new in 0.7.0""" - data = hologram(size=size) - holo = qpretrieve.OffAxisHologram(data) - - ft_data = holo.fft_origin - holo.run_pipeline() - fsx, fsy = holo.pipeline_kws["sideband_freq"] - - kwargs1 = dict(filter_name="disk", - filter_size=1/3, - filter_size_interpretation="sideband distance") - res1 = holo.run_pipeline(**kwargs1) - - filter_size_fi = np.sqrt(fsx**2 + fsy**2) / 3 * ft_data.shape[0] - kwargs2 = dict(filter_name="disk", - filter_size=filter_size_fi, - filter_size_interpretation="frequency index", - ) - res2 = holo.run_pipeline(**kwargs2) - assert np.all(res1 == res2) - - -@pytest.mark.parametrize("size", [62, 63, 64]) -def test_get_field_interpretation_fourier_index_control(size): - """Filter size in Fourier space using Fourier index new in 0.7.0""" - data = hologram(size=size) - holo = qpretrieve.OffAxisHologram(data) - - ft_data = holo.fft_origin - holo.run_pipeline() - fsx, fsy = holo.pipeline_kws["sideband_freq"] - - evil_factor = 1.1 - - kwargs1 = dict(filter_name="disk", - filter_size=1/3 * evil_factor, - filter_size_interpretation="sideband distance" - ) - res1 = holo.run_pipeline(**kwargs1) - - filter_size_fi = np.sqrt(fsx**2 + fsy**2) / 3 * ft_data.shape[0] - kwargs2 = dict(filter_name="disk", - filter_size=filter_size_fi, - filter_size_interpretation="frequency index", - ) - res2 = holo.run_pipeline(**kwargs2) - assert not np.all(res1 == res2) - - -@pytest.mark.parametrize("size", [62, 63, 64, 134, 135]) -@pytest.mark.parametrize("filter_size", [17, 17.01]) -def test_get_field_interpretation_fourier_index_mask_1(size, filter_size): - """Make sure filter size in Fourier space pixels is correct""" - data = hologram(size=size) - holo = qpretrieve.OffAxisHologram(data) - - kwargs2 = dict(filter_name="disk", - filter_size=filter_size, - filter_size_interpretation="frequency index", - ) - holo.run_pipeline(**kwargs2) - mask = holo.fft_filtered.real != 0 - - # We get 17*2+1, because we measure from the center of Fourier - # space and a pixel is included if its center is withing the - # perimeter of the disk. - assert np.sum(np.sum(mask, axis=0) != 0) == 17*2 + 1 - - -@pytest.mark.parametrize("size", [62, 63, 64, 134, 135]) -def test_get_field_interpretation_fourier_index_mask_2(size): - """Filter size in Fourier space using Fourier index new in 0.7.0""" - data = hologram(size=size) - holo = qpretrieve.OffAxisHologram(data) - - kwargs2 = dict(filter_name="disk", - filter_size=16.99, - filter_size_interpretation="frequency index" - ) - holo.run_pipeline(**kwargs2) - mask = holo.fft_filtered.real != 0 - - # We get two points less than in the previous test, because we - # loose on each side of the spectrum. - assert np.sum(np.sum(mask, axis=0) != 0) == 17*2 - 1 - - -def test_get_field_int_copy(): - data = hologram() - data = np.array(data, dtype=int) - - kwargs = dict(filter_size=1 / 3) - - holo1 = qpretrieve.OffAxisHologram(data, copy=False) - res1 = holo1.run_pipeline(**kwargs) - - holo2 = qpretrieve.OffAxisHologram(data, copy=True) - res2 = holo2.run_pipeline(**kwargs) - - holo3 = qpretrieve.OffAxisHologram(data.astype(float), copy=True) - res3 = holo3.run_pipeline(**kwargs) - - assert np.all(res1 == res2) - assert np.all(res1 == res3) - - -def test_get_field_sideband(): - data = hologram() - holo = qpretrieve.OffAxisHologram(data) - holo.run_pipeline() - invert_phase = holo.pipeline_kws["invert_phase"] - - kwargs = dict(filter_name="disk", - filter_size=1 / 3) - - res1 = holo.run_pipeline(invert_phase=False, **kwargs) - res2 = holo.run_pipeline(invert_phase=invert_phase, **kwargs) - assert np.all(res1 == res2) - - -def test_get_field_three_axes(): - data1 = hologram() - # create a copy with empty entry in third axis - data2 = np.zeros((data1.shape[0], data1.shape[1], 2)) - data2[:, :, 0] = data1 - - holo1 = qpretrieve.OffAxisHologram(data1) - holo2 = qpretrieve.OffAxisHologram(data2) - - kwargs = dict(filter_name="disk", - filter_size=1 / 3) - res1 = holo1.run_pipeline(**kwargs) - res2 = holo2.run_pipeline(**kwargs) - assert np.all(res1 == res2) diff --git a/tests/test_qlsi.py b/tests/test_qlsi.py index f839114..38c813f 100644 --- a/tests/test_qlsi.py +++ b/tests/test_qlsi.py @@ -2,8 +2,12 @@ import h5py import numpy as np -import qpretrieve +from skimage.restoration import unwrap_phase +import qpretrieve +from qpretrieve.data_array_layout import ( + convert_data_to_3d_array_layout +) data_path = pathlib.Path(__file__).parent / "data" @@ -25,7 +29,186 @@ def test_qlsi_phase(): assert qlsi.wavefront.argmax() == 242294 assert np.allclose(qlsi.wavefront.max(), 8.179288852406586e-08, atol=0, rtol=1e-12) - assert qlsi.phase.argmax() == 242294 assert np.allclose(qlsi.phase.max(), 0.9343997734657971, atol=0, rtol=1e-12) + + +def test_qlsi_phase_3d(): + with h5py.File(data_path / "qlsi_paa_bead.h5") as h5: + data_3d, _ = convert_data_to_3d_array_layout(h5["0"][:]) + data_3d = np.repeat(data_3d, repeats=10, axis=0) + assert data_3d.shape == (10, 720, 720) + qlsi = qpretrieve.QLSInterferogram( + data=data_3d, + reference=h5["reference"][:], + filter_name="tukey", + filter_size=180, + filter_size_interpretation="frequency index", + wavelength=h5["0"].attrs["wavelength"], + qlsi_pitch_term=h5["0"].attrs["qlsi_pitch_term"], + ) + qlsi.run_pipeline() + assert qlsi.pipeline_kws["wavelength"] == 550e-9 + assert qlsi.pipeline_kws["qlsi_pitch_term"] == 1.87711e-08 + for wavefront in qlsi.wavefront: + assert wavefront.argmax() == 242294 + assert np.allclose(wavefront.max(), 8.179288852406586e-08, + atol=0, rtol=1e-12) + + for phase in qlsi.phase: + assert phase.argmax() == 242294 + assert np.allclose(qlsi.phase.max(), 0.9343997734657971, + atol=0, rtol=1e-12) + + +def test_qlsi_fftfreq_reshape_2d_3d(hologram): + data_2d = hologram + data_3d, _ = qpretrieve.data_array_layout._convert_2d_to_3d(data_2d) + + fx_2d = np.fft.fftfreq(data_2d.shape[-1]).reshape(-1, 1) + fx_3d = np.fft.fftfreq(data_3d.shape[-1]).reshape(data_3d.shape[0], -1, 1) + + fy_2d = np.fft.fftfreq(data_2d.shape[-2]).reshape(1, -1) + fy_3d = np.fft.fftfreq(data_3d.shape[-2]).reshape(data_3d.shape[0], 1, -1) + + assert np.array_equal(fx_2d, fx_3d[0]) + assert np.array_equal(fy_2d, fy_3d[0]) + + +def test_qlsi_unwrap_phase_2d_3d(): + """ + Check whether `skimage.restoration.unwrap_phase` unwraps 2d + images along the z axis when given a 3d array input. + Answer is no. `unwrap_phase` is designed for to unwrap data + on all axes at once. + """ + with h5py.File(data_path / "qlsi_paa_bead.h5") as h5: + image = h5["0"][:] + + # Standard analysis pipeline + pipeline_kws = { + 'wavelength': 550e-9, + 'qlsi_pitch_term': 1.87711e-08, + 'filter_name': "disk", + 'filter_size': 180, + 'filter_size_interpretation': "frequency index", + 'scale_to_filter': False, + 'invert_phase': False + } + + data_2d = image + data_3d, _ = qpretrieve.data_array_layout._convert_2d_to_3d(data_2d) + + ft_2d = qpretrieve.fourier.FFTFilterNumpy(data_2d, subtract_mean=False) + ft_3d = qpretrieve.fourier.FFTFilterNumpy(data_3d, subtract_mean=False) + + pipeline_kws["sideband_freq"] = qpretrieve.interfere. \ + if_qlsi.find_peaks_qlsi(ft_2d.fft_origin[0]) + + hx_2d = ft_2d.filter(filter_name=pipeline_kws["filter_name"], + filter_size=pipeline_kws["filter_size"], + scale_to_filter=pipeline_kws["scale_to_filter"], + freq_pos=pipeline_kws["sideband_freq"]) + hx_3d = ft_3d.filter(filter_name=pipeline_kws["filter_name"], + filter_size=pipeline_kws["filter_size"], + scale_to_filter=pipeline_kws["scale_to_filter"], + freq_pos=pipeline_kws["sideband_freq"]) + + assert np.array_equal(hx_2d, hx_3d) + + px_2d = unwrap_phase(np.angle(hx_2d[0])) + + px_3d_loop = np.zeros_like(hx_3d) + for i, _hx in enumerate(hx_3d): + px_3d_loop[i] = unwrap_phase(np.angle(_hx)) + + assert np.array_equal(px_2d, px_3d_loop[0]) # this passes + + px_3d = unwrap_phase(np.angle(hx_3d)) # this is not equivalent + assert not np.array_equal(px_2d, px_3d[0]) + + +def test_qlsi_rotate_2d_3d(hologram): + """ + Ensure the old 2d and new 3d rotation is identical. + Note that the hologram is used only as an example input image, + and it is not the correct data type for QLSI. + """ + data_2d = hologram + data_3d, _ = qpretrieve.data_array_layout._convert_2d_to_3d(data_2d) + + rot_2d = qpretrieve.interfere.if_qlsi.rotate_noreshape( + data_2d, + angle=2, + axes=(1, 0), # this was the default used before + reshape=False, + ) + rot_3d = qpretrieve.interfere.if_qlsi.rotate_noreshape( + data_3d, + angle=2, + axes=(-1, -2), # the y and x axes + reshape=False, + ) + rot_3d_2 = qpretrieve.interfere.if_qlsi.rotate_noreshape( + data_3d, + angle=2, + axes=(-2, -1), # the y and x axes + reshape=False, + ) + + assert rot_2d.dtype == rot_3d.dtype + assert np.array_equal(rot_2d, rot_3d[0]) + assert np.array_equal(rot_2d, rot_3d_2[0]) + + +def test_qlsi_pad_2d_3d(hologram): + """ + Ensure the old 2d and new 3d padding is identical. + Note that the hologram is used only as an example input image, + and it is not the correct data type for QLSI. + """ + data_2d = hologram + data_3d, _ = qpretrieve.data_array_layout._convert_2d_to_3d(data_2d) + + sx, sy = data_2d.shape[-2:] + gradpad_2d = np.pad( + data_2d, ((sx // 2, sx // 2), (sy // 2, sy // 2)), + mode="constant", constant_values=0) + gradpad_3d = np.pad( + data_3d, ((0, 0), (sx // 2, sx // 2), (sy // 2, sy // 2)), + mode="constant", constant_values=0) + + assert gradpad_2d.dtype == gradpad_3d.dtype + assert np.array_equal(gradpad_2d, gradpad_3d[0]) + + +def test_fxy_complex_mul(hologram): + """ + Ensure the old 2d and new 3d complex multiplication is identical. + Note that the hologram is used only as an example input image, + and it is not the correct data type for QLSI. + """ + data_2d = hologram + data_3d, _ = qpretrieve.data_array_layout._convert_2d_to_3d(data_2d) + + assert np.array_equal(data_2d, data_3d[0]) + + # 2d + fx_2d = np.fft.fftfreq(data_2d.shape[0]).reshape(-1, 1) + fy_2d = np.fft.fftfreq(data_2d.shape[1]).reshape(1, -1) + fxy_2d = -2 * np.pi * 1j * (fx_2d + 1j * fy_2d) + assert fxy_2d.shape == (64, 64) + fxy_2d[0, 0] = 1 + + # 3d + fx_3d = np.fft.fftfreq(data_3d.shape[-2]).reshape(-1, 1) + fy_3d = np.fft.fftfreq(data_3d.shape[-1]).reshape(1, -1) + fxy_3d = -2 * np.pi * 1j * (fx_3d + 1j * fy_3d) + fxy_3d = np.repeat(fxy_3d[np.newaxis, :, :], + repeats=data_3d.shape[0], axis=0) + assert fxy_3d.shape == (1, 64, 64) + fxy_3d[:, 0, 0] = 1 + + assert np.array_equal(fx_2d, fx_3d) + assert np.array_equal(fxy_2d, fxy_3d[0]) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d5ccdab --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,44 @@ +import numpy as np + +from qpretrieve.utils import _padding_2d, padding_3d, _mean_2d, mean_3d + + +def test_mean_subtraction(): + data_3d = np.random.rand(1000, 5, 5).astype(np.float32) + ind = 5 + data_2d = data_3d.copy()[ind] + + data_2d = _mean_2d(data_2d) + data_3d = mean_3d(data_3d) + + assert np.array_equal(data_3d[ind], data_2d) + + +def test_mean_subtraction_consistent_2d_3d(): + """Probably a bit too cumbersome, and changes the default 2d pipeline.""" + data_3d = np.random.rand(1000, 5, 5).astype(np.float32) + ind = 5 + data_2d = data_3d.copy()[ind] + + # too cumbersome + data_2d = np.atleast_3d(data_2d) + data_2d = np.swapaxes(np.swapaxes(data_2d, 0, 2), 1, 2) + data_2d -= data_2d.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis] + + data_3d = np.atleast_3d(data_3d.copy()) + data_3d -= data_3d.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis] + + assert np.array_equal(data_3d[ind], data_2d[0]) + + +def test_batch_padding(): + data_3d = np.random.rand(1000, 100, 320).astype(np.float32) + ind = 5 + data_2d = data_3d.copy()[ind] + order = 512 + dtype = float + + data_2d_padded = _padding_2d(data_2d, order, dtype) + data_3d_padded = padding_3d(data_3d, order, dtype) + + assert np.array_equal(data_3d_padded[ind], data_2d_padded)