From 8e1f2925b27140d0a8d3641bd8d375db3c56ce79 Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Mon, 7 Oct 2024 13:44:56 -0700 Subject: [PATCH] debugging progress --- .../inplane_oriented_thick_pol3d_vector.py | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/waveorder/models/inplane_oriented_thick_pol3d_vector.py b/waveorder/models/inplane_oriented_thick_pol3d_vector.py index d6a6f26e..5853c497 100644 --- a/waveorder/models/inplane_oriented_thick_pol3d_vector.py +++ b/waveorder/models/inplane_oriented_thick_pol3d_vector.py @@ -2,6 +2,7 @@ import numpy as np from torch import Tensor +from typing import Literal from torch.nn.functional import avg_pool3d from waveorder import optics, sampling, stokes, util from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer @@ -16,8 +17,8 @@ def generate_test_phantom(zyx_shape): margin=50, ) c00 = yx_star - c2_2 = -torch.sin(2 * yx_theta) * yx_star - c22 = -torch.cos(2 * yx_theta) * yx_star + c2_2 = -torch.sin(2 * yx_theta) * yx_star #torch.zeros_like(c00) + c22 = -torch.cos(2 * yx_theta) * yx_star #torch.zeros_like(c00) # # Put in a center slices of a 3D object center_slice_object = torch.stack((c00, c2_2, c22), dim=0) @@ -117,7 +118,8 @@ def _calculate_wrap_unsafe_transfer_function( swing, scheme=scheme ) - input_jones = torch.tensor([0.0 + 1.0j, 1.0 + 0j]) # circular + input_jones = torch.tensor([0.0 - 1.0j, 1.0 + 0j]) # circular + # input_jones = torch.tensor([0 + 0j, 1 + 0j]) # linear # Calculate frequencies y_frequencies, x_frequencies = util.generate_frequencies( @@ -131,6 +133,7 @@ def _calculate_wrap_unsafe_transfer_function( ) if invert_phase_contrast: z_position_list = torch.flip(z_position_list, dims=(0,)) + z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size) # 2D pupils ill_pupil = optics.generate_pupil( @@ -195,13 +198,28 @@ def _calculate_wrap_unsafe_transfer_function( G_3D = torch.abs(torch.fft.ifft(G, dim=-3)) * (-1j) S_3D = torch.fft.ifft(S, dim=-3) + # cleanup + freq_shape = z_position_list.shape + x_frequencies.shape + + z_broadcast = torch.broadcast_to(z_frequencies[:, None, None], freq_shape) + y_broadcast = torch.broadcast_to(y_frequencies[None, :, :], freq_shape) + x_broadcast = torch.broadcast_to(x_frequencies[None, :, :], freq_shape) + + nu_rr = torch.sqrt(z_broadcast**2 + y_broadcast**2 + x_broadcast**2) + wavelength = wavelength_illumination / index_of_refraction_media + nu_max = (17 / 16) / (wavelength) + nu_min = (15 / 16) / (wavelength) + + mask = torch.logical_and(nu_rr < nu_max, nu_rr > nu_min) + + P_3D *= mask + G_3D *= mask + S_3D *= mask # Main part PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D) PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D)) - PG_3D /= torch.amax(torch.abs(PG_3D)) - PS_3D /= torch.amax(torch.abs(PS_3D)) pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1)) ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1)) @@ -267,14 +285,17 @@ def apply_transfer_function( ) szyx_data = torch.fft.ifftn(sZYX_data, dim=(1, 2, 3)) - return (50 * szyx_data) + 0.1 * torch.randn(szyx_data.shape) + return 50 * szyx_data # + 0.1 * torch.randn(szyx_data.shape) def apply_inverse_transfer_function( szyx_data: Tensor, singular_system: tuple[Tensor], intensity_to_stokes_matrix: Tensor, + reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov", regularization_strength: float = 1e-3, + TV_rho_strength: float = 1e-3, + TV_iterations: int = 10, ): sZYX_data = torch.fft.fftn(szyx_data, dim=(1, 2, 3))